Move edge cases to gemm ukr; more user-custom mods. (#583)

Details:
- Moved edge-case handling into the gemm microkernel. This required
  changing the microkernel API to take m and n dimension parameters,
  which in turn meant updating all existing gemm microkernel function
  pointer types, function signatures, and related definitions. We also
  updated all existing kernels in the 'kernels' directory to take m and
  n dimensions, and implemented edge-case handling within those
  microkernels via a collection of new C preprocessor macros defined in
  bli_edge_case_macro_defs.h. (A hand-written sketch of the new calling
  convention appears after this list.) Also removed the assembly code
  that formerly handled general stride IO on the microtile, since this
  is now handled by the same code that handles edge cases.
- Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and
  bli_trsm_cntl_create(), where this function pointer is used in lieu of 
  the default macrokernel when it is non-NULL, and ignored when it is
  NULL.
- Re-implemented the macrokernel in bli_gemm_ker_var2.c as a single
  function using byte pointers rather than one function for each
  floating-point datatype. Also, obtain the microkernel function pointer
  from the .ukr field of the params struct embedded within the obj_t
  for matrix C (assuming params is non-NULL and contains a non-NULL
  value in the .ukr field). Communicate both the gemm microkernel
  pointer to use as well as the params struct to the microkernel via
  the auxinfo_t struct.
- Defined gemm_ker_params_t type (for the aforementioned obj_t.params 
  struct) in bli_gemm_var.h.
- Retired the separate _md macrokernel for mixed datatype computation.
  We now use the reimplemented bli_gemm_ker_var2() instead.
- Updated gemmt macrokernels to pass m and n dimensions into microkernel
  calls.
- Removed edge-case handling from trmm and trsm macrokernels.
- Moved most of the code in bli_packm_alloc() into a new helper function,
  bli_packm_alloc_ex().
- Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c.
- Added test/syrk_diagonal and test/tensor_contraction directories with
  associated code to test those operations.
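
For illustration, below is a minimal, hand-written C sketch (not taken from
this commit) of what a double-precision gemm microkernel looks like under the
new API. The blocking factors MY_MR/MY_NR are hypothetical, the packed-panel
indexing assumes PACKMR == MY_MR and PACKNR == MY_NR, and the real kernels
implement the m/n writeback via the macros in bli_edge_case_macro_defs.h
rather than by hand; a working kernel must also agree with the MR/NR and
packing parameters registered in the context.

#include "blis.h"

/* Hypothetical register blocking; real values come from the cntx_t. */
#define MY_MR 8
#define MY_NR 6

void my_dgemm_ukr
     (
       dim_t               m,    /* 0 < m <= MY_MR; m < MY_MR at edge cases */
       dim_t               n,    /* 0 < n <= MY_NR; n < MY_NR at edge cases */
       dim_t               k,
       double*    restrict alpha,
       double*    restrict a,    /* packed micropanel; assumes PACKMR == MY_MR */
       double*    restrict b,    /* packed micropanel; assumes PACKNR == MY_NR */
       double*    restrict beta,
       double*    restrict c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
    double ab[ MY_MR * MY_NR ];

    ( void )data; ( void )cntx;

    /* ab := a * b, accumulating over the k dimension on the full tile. */
    for ( dim_t i = 0; i < MY_MR * MY_NR; ++i ) ab[ i ] = 0.0;

    for ( dim_t l = 0; l < k; ++l )
        for ( dim_t j = 0; j < MY_NR; ++j )
            for ( dim_t i = 0; i < MY_MR; ++i )
                ab[ j * MY_MR + i ] += a[ l * MY_MR + i ] * b[ l * MY_NR + j ];

    /* Write back only the m x n subtile of C that actually exists. This is
       the edge-case handling that previously lived in the macrokernel. */
    for ( dim_t j = 0; j < n; ++j )
        for ( dim_t i = 0; i < m; ++i )
        {
            double* cij = &c[ i * rs_c + j * cs_c ];

            if ( *beta == 0.0 ) *cij = (*alpha) * ab[ j * MY_MR + i ];
            else                *cij = (*beta) * (*cij)
                                     + (*alpha) * ab[ j * MY_MR + i ];
        }
}

A kernel written this way handles interior tiles (m == MY_MR, n == MY_NR) and
partial edge tiles with the same store loop, which is also what allows the
same path to subsume the general-stride IO that the removed assembly used to
handle.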
Devin Matthews committed 2021-12-24 08:00:33 -06:00 (committed by GitHub)
parent 961d9d509d
commit 54fa28bd84
87 changed files with 10980 additions and 14028 deletions


@@ -37,6 +37,8 @@
 void bli_zgemm_template_noopt
      (
+       dim_t              m,
+       dim_t              n,
        dim_t              k,
        dcomplex* restrict alpha,
        dcomplex* restrict a1,
@@ -88,8 +90,7 @@ void bli_zgemm_template_noopt
     dim_t     l, j, i;
-    dcomplex  ab[ bli_zmr *
-                  bli_znr ];
+    dcomplex  ab[ mr * nr ];
     dcomplex* abij;
     dcomplex  ai, bj;
@@ -137,16 +138,16 @@ void bli_zgemm_template_noopt
     if ( bli_zeq0( *beta ) )
     {
         /* c11 := ab */
-        bli_zcopys_mxn( mr,
-                        nr,
+        bli_zcopys_mxn( m,
+                        n,
                         ab,  rs_ab, cs_ab,
                         c11, rs_c,  cs_c );
     }
     else
     {
         /* c11 := beta * c11 + ab */
-        bli_zxpbys_mxn( mr,
-                        nr,
+        bli_zxpbys_mxn( m,
+                        n,
                         ab,  rs_ab, cs_ab,
                         beta,
                         c11, rs_c,  cs_c );


@@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt
     */
     const num_t dt     = BLIS_DCOMPLEX;
+    const inc_t mr     = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
+    const inc_t nr     = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
     const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
     const inc_t rs_b   = packnr;
@@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt
     /* b11 = alpha * b11 - a10 * b01; */
     bli_zgemm_template_noopt
     (
+      mr,
+      nr,
       k,
       minus_one,
       a10,


@@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt
     */
     const num_t dt     = BLIS_DCOMPLEX;
+    const inc_t mr     = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
+    const inc_t nr     = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
     const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
     const inc_t rs_b   = packnr;
@@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt
     /* b11 = alpha * b11 - a12 * b21; */
     bli_zgemm_template_noopt
     (
+      mr,
+      nr,
       k,
       minus_one,
-      a12,
-      b21,
+      a10,
+      b01,
       alpha,
       b11, rs_b, cs_b,
       data


@@ -36,16 +36,35 @@
 #include "blis.h"
 
 void* bli_packm_alloc
      (
        siz_t      size_needed,
        rntm_t*    rntm,
        cntl_t*    cntl,
        thrinfo_t* thread
      )
 {
     // Query the pack buffer type from the control tree node.
     packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl );
 
+    return bli_packm_alloc_ex
+    (
+      size_needed,
+      pack_buf_type,
+      rntm,
+      cntl,
+      thread
+    );
+}
+
+void* bli_packm_alloc_ex
+     (
+       siz_t      size_needed,
+       packbuf_t  pack_buf_type,
+       rntm_t*    rntm,
+       cntl_t*    cntl,
+       thrinfo_t* thread
+     )
+{
     // Query the address of the mem_t entry within the control tree node.
     mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl );
@@ -55,7 +74,7 @@ void* bli_packm_alloc
     siz_t cntl_mem_size = 0;
 
     if ( bli_mem_is_alloc( cntl_mem_p ) )
         cntl_mem_size = bli_mem_size( cntl_mem_p );
 
     if ( cntl_mem_size < size_needed )
     {
@@ -64,14 +83,15 @@ void* bli_packm_alloc
         // The chief thread releases the existing block associated with
         // the mem_t entry in the control tree, and then re-acquires a
         // new block, saving the associated mem_t entry to local_mem_s.
         if ( bli_mem_is_alloc( cntl_mem_p ) )
         {
             bli_pba_release
             (
               rntm,
               cntl_mem_p
             );
         }
 
         bli_pba_acquire_m
         (
           rntm,
@@ -89,11 +109,11 @@ void* bli_packm_alloc
         // this thread's control tree node.
         *cntl_mem_p = *local_mem_p;
 
         // Barrier so that the master thread doesn't return from the function
         // before we are done reading.
         bli_thread_barrier( thread );
     }
 
     return bli_mem_buffer( cntl_mem_p );
 }


@@ -32,11 +32,20 @@
 */
 
 BLIS_EXPORT_BLIS void* bli_packm_alloc
      (
        siz_t      size_needed,
        rntm_t*    rntm,
        cntl_t*    cntl,
        thrinfo_t* thread
      );
+
+BLIS_EXPORT_BLIS void* bli_packm_alloc_ex
+     (
+       siz_t      size_needed,
+       packbuf_t  pack_buf_type,
+       rntm_t*    rntm,
+       cntl_t*    cntl,
+       thrinfo_t* thread
+     );


@@ -57,7 +57,14 @@ void bli_l3_cntl_create_if
          family == BLIS_GEMMT ||
          family == BLIS_TRMM )
     {
-        *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b );
+        *cntl_use = bli_gemm_cntl_create
+        (
+          rntm,
+          family,
+          schema_a,
+          schema_b,
+          bli_obj_ker_fn( c )
+        );
     }
     else // if ( family == BLIS_TRSM )
     {
@@ -66,7 +73,14 @@ void bli_l3_cntl_create_if
         if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT;
         else                              side = BLIS_RIGHT;
 
-        *cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b );
+        *cntl_use = bli_trsm_cntl_create
+        (
+          rntm,
+          side,
+          schema_a,
+          schema_b,
+          bli_obj_ker_fn( c )
+        );
     }
 }
 else


@@ -47,6 +47,8 @@
 \
 typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \
      ( \
+       dim_t           m, \
+       dim_t           n, \
        dim_t           k, \
        ctype* restrict alpha, \
        ctype* restrict a, \


@@ -51,6 +51,8 @@ void PASTEMAC0(opname) \
 \
     num_t dt    = bli_obj_dt( c ); \
 \
+    dim_t m     = bli_obj_length( c ); \
+    dim_t n     = bli_obj_width( c ); \
     dim_t k     = bli_obj_width( a ); \
     void* buf_a = bli_obj_buffer_at_off( a ); \
     void* buf_b = bli_obj_buffer_at_off( b ); \
@@ -75,6 +77,8 @@ void PASTEMAC0(opname) \
 \
     f \
     ( \
+      m, \
+      n, \
       k, \
       buf_alpha, \
       buf_a, \


@@ -42,6 +42,8 @@
 \
 void PASTEMAC(ch,opname) \
      ( \
+       dim_t               m, \
+       dim_t               n, \
        dim_t               k, \
        ctype_out* restrict alpha, \
        ctype_in*  restrict a, \


@@ -39,6 +39,8 @@
 \
 void PASTEMAC(ch,opname) \
      ( \
+       dim_t           m, \
+       dim_t           n, \
        dim_t           k, \
        ctype* restrict alpha, \
        ctype* restrict a, \
@@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \
     PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
 \
     /* Invoke the typed function for the given datatype. */ \
-    f( \
-       k, \
-       alpha, \
-       a, \
-       b, \
-       beta, \
-       c, rs_c, cs_c, \
-       data, \
-       cntx \
-     ); \
+    f \
+    ( \
+      m, \
+      n, \
+      k, \
+      alpha, \
+      a, \
+      b, \
+      beta, \
+      c, rs_c, cs_c, \
+      data, \
+      cntx \
+    ); \
 } \
 
 INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR )
@@ -98,17 +103,18 @@ void PASTEMAC(ch,opname) \
     PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
 \
     /* Invoke the typed function for the given datatype. */ \
-    f( \
-       k, \
-       alpha, \
-       a1x, \
-       a11, \
-       bx1, \
-       b11, \
-       c11, rs_c, cs_c, \
-       data, \
-       cntx \
-     ); \
+    f \
+    ( \
+      k, \
+      alpha, \
+      a1x, \
+      a11, \
+      bx1, \
+      b11, \
+      c11, rs_c, cs_c, \
+      data, \
+      cntx \
+    ); \
 } \
 
 INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR )
@@ -136,13 +142,14 @@ void PASTEMAC(ch,opname) \
     PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
 \
     /* Invoke the typed function for the given datatype. */ \
-    f( \
-       a, \
-       b, \
-       c, rs_c, cs_c, \
-       data, \
-       cntx \
-     ); \
+    f \
+    ( \
+      a, \
+      b, \
+      c, rs_c, cs_c, \
+      data, \
+      cntx \
+    ); \
 } \
 
 INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR )


@@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create
        rntm_t* rntm,
        opid_t  family,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      )
 {
-    return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b );
+    return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker );
 }
 
 // -----------------------------------------------------------------------------
@@ -53,18 +54,22 @@ cntl_t* bli_gemmbp_cntl_create
        rntm_t* rntm,
        opid_t  family,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      )
 {
     void_fp macro_kernel_fp;
 
-    // Use the function pointers to the macrokernels that use slab
-    // assignment of micropanels to threads in the jr and ir loops.
+    // Choose the default macrokernel based on the operation family...
     if      ( family == BLIS_GEMM )  macro_kernel_fp = bli_gemm_ker_var2;
     else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
     else if ( family == BLIS_TRMM )  macro_kernel_fp = bli_trmm_xx_ker_var2;
     else /* should never execute */  macro_kernel_fp = NULL;
 
+    // ...unless a non-NULL kernel function pointer is passed in, in which
+    // case we use that instead.
+    if ( ker ) macro_kernel_fp = ker;
+
     // Create two nodes for the macro-kernel.
     cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node
     (


@@ -38,7 +38,8 @@ cntl_t* bli_gemm_cntl_create
        rntm_t* rntm,
        opid_t  family,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      );
 
 // -----------------------------------------------------------------------------
@@ -48,7 +49,8 @@ cntl_t* bli_gemmbp_cntl_create
        rntm_t* rntm,
        opid_t  family,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      );
 
 #if 0


@@ -283,90 +283,3 @@ void bli_gemm_front
 #endif
 }
-
-// -----------------------------------------------------------------------------
-
-#if 0
-    if ( bli_obj_dt( a ) != bli_obj_dt( b ) ||
-         bli_obj_dt( a ) != bli_obj_dt( c ) ||
-         bli_obj_comp_prec( c ) != bli_obj_prec( c ) )
-    {
-        const bool a_is_real = bli_obj_is_real( a );
-        const bool a_is_comp = bli_obj_is_complex( a );
-        const bool b_is_real = bli_obj_is_real( b );
-        const bool b_is_comp = bli_obj_is_complex( b );
-        const bool c_is_real = bli_obj_is_real( c );
-        const bool c_is_comp = bli_obj_is_complex( c );
-
-        const bool a_is_single = bli_obj_is_single_prec( a );
-        const bool a_is_double = bli_obj_is_double_prec( a );
-        const bool b_is_single = bli_obj_is_single_prec( b );
-        const bool b_is_double = bli_obj_is_double_prec( b );
-        const bool c_is_single = bli_obj_is_single_prec( c );
-        const bool c_is_double = bli_obj_is_double_prec( c );
-
-        const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC;
-        const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC;
-
-        const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) ||
-                                 bli_obj_domain( c ) != bli_obj_domain( b );
-
-        ( void )a_is_real; ( void )a_is_comp;
-        ( void )b_is_real; ( void )b_is_comp;
-        ( void )c_is_real; ( void )c_is_comp;
-        ( void )a_is_single; ( void )a_is_double;
-        ( void )b_is_single; ( void )b_is_double;
-        ( void )c_is_single; ( void )c_is_double;
-        ( void )comp_single; ( void )comp_double;
-
-        if (
-             //( c_is_comp && a_is_comp && b_is_real ) ||
-             //( c_is_comp && a_is_real && b_is_comp ) ||
-             //( c_is_real && a_is_comp && b_is_comp ) ||
-             //( c_is_comp && a_is_real && b_is_real ) ||
-             //( c_is_real && a_is_comp && b_is_real ) ||
-             //( c_is_real && a_is_real && b_is_comp ) ||
-             //FALSE
-             TRUE
-           )
-        {
-            if (
-                 ( c_is_single && a_is_single && b_is_single && mixeddomain ) ||
-                 ( c_is_single && a_is_single && b_is_single && comp_single ) ||
-                 ( c_is_single && a_is_single && b_is_single && comp_double ) ||
-                 ( c_is_single && a_is_single && b_is_double ) ||
-                 ( c_is_single && a_is_double && b_is_single ) ||
-                 ( c_is_double && a_is_single && b_is_single ) ||
-                 ( c_is_single && a_is_double && b_is_double ) ||
-                 ( c_is_double && a_is_single && b_is_double ) ||
-                 ( c_is_double && a_is_double && b_is_single ) ||
-                 ( c_is_double && a_is_double && b_is_double && comp_single ) ||
-                 ( c_is_double && a_is_double && b_is_double && comp_double ) ||
-                 ( c_is_double && a_is_double && b_is_double && mixeddomain ) ||
-                 FALSE
-               )
-                bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl );
-            else
-                bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl );
-        }
-        else
-            bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl );
-
-        return;
-    }
-#else
-#if 0
-    // If any of the storage datatypes differ, or if the execution precision
-    // differs from the storage precision of C, utilize the mixed datatype
-    // code path.
-    // NOTE: We could check the exec dt against the storage dt of C, but for
-    // now we don't support the caller setting the execution domain
-    // explicitly.
-    if ( bli_obj_dt( a ) != bli_obj_dt( b ) ||
-         bli_obj_dt( a ) != bli_obj_dt( c ) ||
-         bli_obj_comp_prec( c ) != bli_obj_prec( c ) )
-    {
-        bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl );
-        return;
-    }
-#endif
-#endif


@@ -35,28 +35,44 @@
 #include "blis.h"
 
-#define FUNCPTR_T gemm_fp
-
-typedef void (*FUNCPTR_T)
-     (
-       pack_t  schema_a,
-       pack_t  schema_b,
-       dim_t   m,
-       dim_t   n,
-       dim_t   k,
-       void*   alpha,
-       void*   a, inc_t cs_a, inc_t is_a,
-       dim_t   pd_a, inc_t ps_a,
-       void*   b, inc_t rs_b, inc_t is_b,
-       dim_t   pd_b, inc_t ps_b,
-       void*   beta,
-       void*   c, inc_t rs_c, inc_t cs_c,
-       cntx_t* cntx,
-       rntm_t* rntm,
-       thrinfo_t* thread
-     );
-
-static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2);
+typedef void (*xpbys_mxn_vft)
+     (
+       dim_t m,
+       dim_t n,
+       void* x, inc_t rs_x, inc_t cs_x,
+       void* b,
+       void* y, inc_t rs_y, inc_t cs_y
+     );
+
+#undef  GENTFUNC2
+#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \
+\
+void PASTEMAC2(chx,chy,op) \
+     ( \
+       dim_t m, \
+       dim_t n, \
+       void* x, inc_t rs_x, inc_t cs_x, \
+       void* b, \
+       void* y, inc_t rs_y, inc_t cs_y \
+     ) \
+{ \
+    ctypex* restrict x_cast = x; \
+    ctypey* restrict b_cast = b; \
+    ctypey* restrict y_cast = y; \
+\
+    PASTEMAC3(chx,chy,chy,xpbys_mxn) \
+    ( \
+      m, n, \
+      x_cast, rs_x, cs_x, \
+      b_cast, \
+      y_cast, rs_y, cs_y \
+    ); \
+}
+
+INSERT_GENTFUNC2_BASIC0(xbpys_mxn_fn);
+INSERT_GENTFUNC2_MIXDP0(xbpys_mxn_fn);
+
+static xpbys_mxn_vft GENARRAY2_ALL(xbpys_mxn, xbpys_mxn_fn);
 
 void bli_gemm_ker_var2
@@ -70,23 +86,8 @@ void bli_gemm_ker_var2
        thrinfo_t* thread
      )
 {
-#ifdef BLIS_ENABLE_GEMM_MD
-    // By now, A and B have been packed and cast to the execution precision.
-    // In most cases, such as when storage precision of C differs from the
-    // execution precision, we utilize the mixed datatype code path. However,
-    // a few cases still fall within this kernel, such as mixed domain with
-    // equal precision (ccr, crc, rcc), hence those expressions being disabled
-    // in the conditional below.
-    if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) ||
-         //( bli_obj_domain( c ) != bli_obj_domain( b ) ) ||
-         ( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) )
-    {
-        bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread );
-        return;
-    }
-#endif
-
     num_t  dt_exec  = bli_obj_exec_dt( c );
+    num_t  dt_c     = bli_obj_dt( c );
 
     pack_t schema_a = bli_obj_pack_schema( a );
     pack_t schema_b = bli_obj_pack_schema( b );
@@ -95,50 +96,55 @@ void bli_gemm_ker_var2
     dim_t  n        = bli_obj_width( c );
     dim_t  k        = bli_obj_width( a );
 
-    void*  buf_a    = bli_obj_buffer_at_off( a );
-    inc_t  cs_a     = bli_obj_col_stride( a );
+    char*  a_cast   = bli_obj_buffer_at_off( a );
     inc_t  is_a     = bli_obj_imag_stride( a );
     dim_t  pd_a     = bli_obj_panel_dim( a );
    inc_t  ps_a     = bli_obj_panel_stride( a );
 
-    void*  buf_b    = bli_obj_buffer_at_off( b );
-    inc_t  rs_b     = bli_obj_row_stride( b );
+    char*  b_cast   = bli_obj_buffer_at_off( b );
    inc_t  is_b     = bli_obj_imag_stride( b );
     dim_t  pd_b     = bli_obj_panel_dim( b );
     inc_t  ps_b     = bli_obj_panel_stride( b );
 
-    void*  buf_c    = bli_obj_buffer_at_off( c );
+    char*  c_cast   = bli_obj_buffer_at_off( c );
     inc_t  rs_c     = bli_obj_row_stride( c );
     inc_t  cs_c     = bli_obj_col_stride( c );
 
-    obj_t  scalar_a;
-    obj_t  scalar_b;
-
-    void*  buf_alpha;
-    void*  buf_beta;
-
-    FUNCPTR_T f;
+    // If any dimension is zero, return immediately.
+    if ( bli_zero_dim3( m, n, k ) ) return;
 
     // Detach and multiply the scalars attached to A and B.
+    // NOTE: We know that the internal scalars of A and B are already of the
+    // target datatypes because the necessary typecasting would have already
+    // taken place during bli_packm_init().
+    obj_t scalar_a;
+    obj_t scalar_b;
     bli_obj_scalar_detach( a, &scalar_a );
     bli_obj_scalar_detach( b, &scalar_b );
     bli_mulsc( &scalar_a, &scalar_b );
 
     // Grab the addresses of the internal scalar buffers for the scalar
     // merged above and the scalar attached to C.
-    buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b );
-    buf_beta  = bli_obj_internal_scalar_buffer( c );
+    // NOTE: We know that scalar_b is of type dt_exec due to the above code
+    // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b,
+    // and we know that the internal scalar in C is already of the type dt_c
+    // due to the casting in the implementation of bli_obj_scalar_attach().
+    char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b );
+    char* beta_cast  = bli_obj_internal_scalar_buffer( c );
 
     // If 1m is being employed on a column- or row-stored matrix with a
     // real-valued beta, we can use the real domain macro-kernel, which
     // eliminates a little overhead associated with the 1m virtual
     // micro-kernel.
+    // Only employ this optimization if the storage datatype of C is
+    // equal to the execution/computation datatype.
 #if 1
     if ( bli_cntx_method( cntx ) == BLIS_1M )
     {
         bli_gemm_ind_recast_1m_params
         (
           &dt_exec,
+          &dt_c,
           schema_a,
           c,
           &m, &n, &k,
@@ -151,273 +157,211 @@ void bli_gemm_ker_var2
 #ifdef BLIS_ENABLE_GEMM_MD
     // Tweak parameters in select mixed domain cases (rcc, crc, ccr).
-    bli_gemm_md_ker_var2_recast
-    (
-      &dt_exec,
-      bli_obj_dt( a ),
-      bli_obj_dt( b ),
-      bli_obj_dt( c ),
-      &m, &n, &k,
-      &pd_a, &ps_a,
-      &pd_b, &ps_b,
-      c,
-      &rs_c, &cs_c
-    );
+    if ( bli_cntx_method( cntx ) == BLIS_NAT )
+    {
+        bli_gemm_md_ker_var2_recast
+        (
+          &dt_exec,
+          bli_obj_dt( a ),
+          bli_obj_dt( b ),
+          &dt_c,
+          &m, &n, &k,
+          &pd_a, &ps_a,
+          &pd_b, &ps_b,
+          c,
+          &rs_c, &cs_c
+        );
+    }
 #endif
 
-    // Index into the type combination array to extract the correct
-    // function pointer.
-    f = ftypes[dt_exec];
-
-    // Invoke the function.
-    f( schema_a,
-       schema_b,
-       m,
-       n,
-       k,
-       buf_alpha,
-       buf_a, cs_a, is_a,
-       pd_a, ps_a,
-       buf_b, rs_b, is_b,
-       pd_b, ps_b,
-       buf_beta,
-       buf_c, rs_c, cs_c,
-       cntx,
-       rntm,
-       thread );
-}
+    siz_t dt_size   = bli_dt_size( dt_exec );
+    siz_t dt_c_size = bli_dt_size( dt_c );
+
+    // Alias some constants to simpler names.
+    const dim_t MR = pd_a;
+    const dim_t NR = pd_b;
+    //const dim_t PACKMR = cs_a;
+    //const dim_t PACKNR = rs_b;
+
+    // Query the context for the micro-kernel address and cast it to its
+    // function pointer type.
+    gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx );
+
+    // Query the params field from the obj_t. If it is non-NULL, grab the ukr
+    // field of the params struct. If that function pointer is non-NULL, use it
+    // as our microkernel instead of the default microkernel queried from the
+    // cntx above.
+    gemm_ker_params_t* params = bli_obj_ker_params( c );
+    gemm_ukr_vft user_ukr = params ? params->ukr : NULL;
+    if ( user_ukr ) gemm_ukr = user_ukr;
+
+    // Temporary C buffer for edge cases. Note that the strides of this
+    // temporary buffer are set so that they match the storage of the
+    // original C matrix. For example, if C is column-stored, ct will be
+    // column-stored as well.
+    char ct[ BLIS_STACK_BUF_MAX_SIZE ]
+             __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
+    const bool  col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx );
+    const inc_t rs_ct    = ( col_pref ? 1 : NR );
+    const inc_t cs_ct    = ( col_pref ? MR : 1 );
+
+    char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO );
+
+    //
+    // Assumptions/assertions:
+    //   rs_a == 1
+    //   cs_a == PACKMR
+    //   pd_a == MR
+    //   ps_a == stride to next micro-panel of A
+    //   rs_b == PACKNR
+    //   cs_b == 1
+    //   pd_b == NR
+    //   ps_b == stride to next micro-panel of B
+    //   rs_c == (no assumptions)
+    //   cs_c == (no assumptions)
+    //
+
+    // Compute number of primary and leftover components of the m and n
+    // dimensions.
+    dim_t n_iter = n / NR;
+    dim_t n_left = n % NR;
+
+    dim_t m_iter = m / MR;
+    dim_t m_left = m % MR;
+
+    if ( n_left ) ++n_iter;
+    if ( m_left ) ++m_iter;
+
+    // Determine some increments used to step through A, B, and C.
+    inc_t rstep_a = ps_a * dt_size;
+
+    inc_t cstep_b = ps_b * dt_size;
+
+    inc_t rstep_c = rs_c * MR * dt_c_size;
+    inc_t cstep_c = cs_c * NR * dt_c_size;
+
+    auxinfo_t aux;
+
+    // Save the pack schemas of A and B to the auxinfo_t object.
+    bli_auxinfo_set_schema_a( schema_a, &aux );
+    bli_auxinfo_set_schema_b( schema_b, &aux );
+
+    // Save the imaginary stride of A and B to the auxinfo_t object.
+    bli_auxinfo_set_is_a( is_a, &aux );
+    bli_auxinfo_set_is_b( is_b, &aux );
+
+    // Save the virtual microkernel address and the params.
+    bli_auxinfo_set_ukr( gemm_ukr, &aux );
+    bli_auxinfo_set_params( params, &aux );
+
+    // The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
+    // loop around the microkernel. Here we query the thrinfo_t node for the
+    // 1st (ir) loop around the microkernel.
+    thrinfo_t* caucus = bli_thrinfo_sub_node( thread );
+
+    // Query the number of threads and thread ids for each loop.
+    dim_t jr_nt  = bli_thread_n_way( thread );
+    dim_t jr_tid = bli_thread_work_id( thread );
+    dim_t ir_nt  = bli_thread_n_way( caucus );
+    dim_t ir_tid = bli_thread_work_id( caucus );
+
+    dim_t jr_start, jr_end;
+    dim_t ir_start, ir_end;
+    dim_t jr_inc,   ir_inc;
+
+    // Determine the thread range and increment for the 2nd and 1st loops.
+    // NOTE: The definition of bli_thread_range_jrir() will depend on whether
+    // slab or round-robin partitioning was requested at configure-time.
+    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
+    bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );
+
+    // Loop over the n dimension (NR columns at a time).
+    for ( dim_t j = jr_start; j < jr_end; j += jr_inc )
+    {
+        char* b1 = b_cast + j * cstep_b;
+        char* c1 = c_cast + j * cstep_c;
+
+        dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );
+
+        // Initialize our next panel of B to be the current panel of B.
+        char* b2 = b1;
+
+        // Loop over the m dimension (MR rows at a time).
+        for ( dim_t i = ir_start; i < ir_end; i += ir_inc )
+        {
+            char* a1  = a_cast + i * rstep_a;
+            char* c11 = c1     + i * rstep_c;
+
+            dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );
+
+            // Compute the addresses of the next panels of A and B.
+            char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
+            if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
+            {
+                a2 = a_cast;
+                b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
+                if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
+                    b2 = b_cast;
+            }
+
+            // Save addresses of next panels of A and B to the auxinfo_t
+            // object.
+            bli_auxinfo_set_next_a( a2, &aux );
+            bli_auxinfo_set_next_b( b2, &aux );
+
+            // Edge case handling now occurs within the microkernel itself, but
+            // we must still explicitly accumulate to a temporary microtile in
+            // situations where a virtual microkernel is being used, such as
+            // during the 1m method or some cases of mixed datatypes.
+            if ( dt_exec == dt_c )
+            {
+                // Invoke the gemm micro-kernel.
+                gemm_ukr
+                (
+                  m_cur,
+                  n_cur,
+                  k,
+                  alpha_cast,
+                  a1,
+                  b1,
+                  beta_cast,
+                  c11, rs_c, cs_c,
+                  &aux,
+                  cntx
+                );
+            }
+            else
+            {
+                // Invoke the gemm micro-kernel.
+                gemm_ukr
+                (
+                  MR,
+                  NR,
+                  k,
+                  alpha_cast,
+                  a1,
+                  b1,
+                  zero,
+                  &ct, rs_ct, cs_ct,
+                  &aux,
+                  cntx
+                );
+
+                // Accumulate to C with type-casting.
+                xbpys_mxn[ dt_exec ][ dt_c ]
+                (
+                  m_cur, n_cur,
+                  &ct, rs_ct, cs_ct,
+                  beta_cast,
+                  c11, rs_c, cs_c
+                );
+            }
+        }
+    }
 
-#undef  GENTFUNC
-#define GENTFUNC( ctype, ch, varname ) \
-\
-void PASTEMAC(ch,varname) \
-     ( \
-       pack_t  schema_a, \
-       pack_t  schema_b, \
-       dim_t   m, \
-       dim_t   n, \
-       dim_t   k, \
-       void*   alpha, \
-       void*   a, inc_t cs_a, inc_t is_a, \
-       dim_t   pd_a, inc_t ps_a, \
-       void*   b, inc_t rs_b, inc_t is_b, \
-       dim_t   pd_b, inc_t ps_b, \
-       void*   beta, \
-       void*   c, inc_t rs_c, inc_t cs_c, \
-       cntx_t* cntx, \
-       rntm_t* rntm, \
-       thrinfo_t* thread \
-     ) \
-{ \
-    const num_t dt = PASTEMAC(ch,type); \
-\
-    /* Alias some constants to simpler names. */ \
-    const dim_t MR = pd_a; \
-    const dim_t NR = pd_b; \
-    /*const dim_t PACKMR = cs_a;*/ \
-    /*const dim_t PACKNR = rs_b;*/ \
-\
-    /* Query the context for the micro-kernel address and cast it to its
-       function pointer type. Note that the virtual gemm ukernel is queried
-       instead of the native gemm ukernel. This is needed for certain
-       situations for the 1m method that require an extra layer of logic
-       to allow for handling (for example) complex values of beta. Also
-       note that under certain circumstances, the real-domain version of
-       this macrokernel will be called for 1m (NOT the complex version)
-       as an optimization. In these cases, the corresponding real-domain
-       slots within the cntx_t's virtual gemm ukernel func_t will contain
-       pointers to the *native* gemm ukernel, thanks to logic in the
-       context initialization function for the induced method (defined
-       in bli_cntx_ref.c). */ \
-    PASTECH(ch,gemm_ukr_ft) \
-    gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
-\
-    /* Temporary C buffer for edge cases. Note that the strides of this
-       temporary buffer are set so that they match the storage of the
-       original C matrix. For example, if C is column-stored, ct will be
-       column-stored as well. */ \
-    ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
-              / sizeof( ctype ) ] \
-              __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
-    const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
-    const inc_t rs_ct = ( col_pref ? 1 : NR ); \
-    const inc_t cs_ct = ( col_pref ? MR : 1 ); \
-\
-    ctype* restrict zero       = PASTEMAC(ch,0); \
-    ctype* restrict a_cast     = a; \
-    ctype* restrict b_cast     = b; \
-    ctype* restrict c_cast     = c; \
-    ctype* restrict alpha_cast = alpha; \
-    ctype* restrict beta_cast  = beta; \
-    ctype* restrict b1; \
-    ctype* restrict c1; \
-\
-    dim_t m_iter, m_left; \
-    dim_t n_iter, n_left; \
-    dim_t i, j; \
-    dim_t m_cur; \
-    dim_t n_cur; \
-    inc_t rstep_a; \
-    inc_t cstep_b; \
-    inc_t rstep_c, cstep_c; \
-    auxinfo_t aux; \
-\
-    /*
-       Assumptions/assertions:
-         rs_a == 1
-         cs_a == PACKMR
-         pd_a == MR
-         ps_a == stride to next micro-panel of A
-         rs_b == PACKNR
-         cs_b == 1
-         pd_b == NR
-         ps_b == stride to next micro-panel of B
-         rs_c == (no assumptions)
-         cs_c == (no assumptions)
-    */ \
-\
-    /* If any dimension is zero, return immediately. */ \
-    if ( bli_zero_dim3( m, n, k ) ) return; \
-\
-    /* Clear the temporary C buffer in case it has any infs or NaNs. */ \
-    PASTEMAC(ch,set0s_mxn)( MR, NR, \
-                            ct, rs_ct, cs_ct ); \
-\
-    /* Compute number of primary and leftover components of the m and n
-       dimensions. */ \
-    n_iter = n / NR; \
-    n_left = n % NR; \
-\
-    m_iter = m / MR; \
-    m_left = m % MR; \
-\
-    if ( n_left ) ++n_iter; \
-    if ( m_left ) ++m_iter; \
-\
-    /* Determine some increments used to step through A, B, and C. */ \
-    rstep_a = ps_a; \
-\
-    cstep_b = ps_b; \
-\
-    rstep_c = rs_c * MR; \
-    cstep_c = cs_c * NR; \
-\
-    /* Save the pack schemas of A and B to the auxinfo_t object. */ \
-    bli_auxinfo_set_schema_a( schema_a, &aux ); \
-    bli_auxinfo_set_schema_b( schema_b, &aux ); \
-\
-    /* Save the imaginary stride of A and B to the auxinfo_t object. */ \
-    bli_auxinfo_set_is_a( is_a, &aux ); \
-    bli_auxinfo_set_is_b( is_b, &aux ); \
-\
-    /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
-       loop around the microkernel. Here we query the thrinfo_t node for the
-       1st (ir) loop around the microkernel. */ \
-    thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \
-\
-    /* Query the number of threads and thread ids for each loop. */ \
-    dim_t jr_nt  = bli_thread_n_way( thread ); \
-    dim_t jr_tid = bli_thread_work_id( thread ); \
-    dim_t ir_nt  = bli_thread_n_way( caucus ); \
-    dim_t ir_tid = bli_thread_work_id( caucus ); \
-\
-    dim_t jr_start, jr_end; \
-    dim_t ir_start, ir_end; \
-    dim_t jr_inc,   ir_inc; \
-\
-    /* Determine the thread range and increment for the 2nd and 1st loops.
-       NOTE: The definition of bli_thread_range_jrir() will depend on whether
-       slab or round-robin partitioning was requested at configure-time. */ \
-    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
-    bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
-\
-    /* Loop over the n dimension (NR columns at a time). */ \
-    for ( j = jr_start; j < jr_end; j += jr_inc ) \
-    { \
-        ctype* restrict a1; \
-        ctype* restrict c11; \
-        ctype* restrict b2; \
-\
-        b1 = b_cast + j * cstep_b; \
-        c1 = c_cast + j * cstep_c; \
-\
-        n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
-\
-        /* Initialize our next panel of B to be the current panel of B. */ \
-        b2 = b1; \
-\
-        /* Loop over the m dimension (MR rows at a time). */ \
-        for ( i = ir_start; i < ir_end; i += ir_inc ) \
-        { \
-            ctype* restrict a2; \
-\
-            a1  = a_cast + i * rstep_a; \
-            c11 = c1     + i * rstep_c; \
-\
-            m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
-\
-            /* Compute the addresses of the next panels of A and B. */ \
-            a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
-            if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
-            { \
-                a2 = a_cast; \
-                b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
-                if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
-                    b2 = b_cast; \
-            } \
-\
-            /* Save addresses of next panels of A and B to the auxinfo_t
-               object. */ \
-            bli_auxinfo_set_next_a( a2, &aux ); \
-            bli_auxinfo_set_next_b( b2, &aux ); \
-\
-            /* Handle interior and edge cases separately. */ \
-            if ( m_cur == MR && n_cur == NR ) \
-            { \
-                /* Invoke the gemm micro-kernel. */ \
-                gemm_ukr \
-                ( \
-                  k, \
-                  alpha_cast, \
-                  a1, \
-                  b1, \
-                  beta_cast, \
-                  c11, rs_c, cs_c, \
-                  &aux, \
-                  cntx \
-                ); \
-            } \
-            else \
-            { \
-                /* Invoke the gemm micro-kernel. */ \
-                gemm_ukr \
-                ( \
-                  k, \
-                  alpha_cast, \
-                  a1, \
-                  b1, \
-                  zero, \
-                  ct, rs_ct, cs_ct, \
-                  &aux, \
-                  cntx \
-                ); \
-\
-                /* Scale the bottom edge of C and add the result from above. */ \
-                PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-                                        ct,  rs_ct, cs_ct, \
-                                        beta_cast, \
-                                        c11, rs_c,  cs_c ); \
-            } \
-        } \
-    } \
-\
 /*
-PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
-PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \
-PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \
-*/ \
+PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" );
+PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" );
+PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" );
+*/
 }
-
-INSERT_GENTFUNC_BASIC0( gemm_ker_var2 )


@@ -1,406 +0,0 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2014, The University of Texas at Austin
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#ifdef BLIS_ENABLE_GEMM_MD
#define FUNCPTR_T gemm_fp
typedef void (*FUNCPTR_T)
(
pack_t schema_a,
pack_t schema_b,
dim_t m,
dim_t n,
dim_t k,
void* alpha,
void* a, inc_t cs_a, inc_t is_a,
dim_t pd_a, inc_t ps_a,
void* b, inc_t rs_b, inc_t is_b,
dim_t pd_b, inc_t ps_b,
void* beta,
void* c, inc_t rs_c, inc_t cs_c,
cntx_t* cntx,
rntm_t* rntm,
thrinfo_t* thread
);
static FUNCPTR_T GENARRAY2_ALL(ftypes,gemm_ker_var2_md);
void bli_gemm_ker_var2_md
(
obj_t* a,
obj_t* b,
obj_t* c,
cntx_t* cntx,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
num_t dt_exec = bli_obj_exec_dt( c );
num_t dt_c = bli_obj_dt( c );
pack_t schema_a = bli_obj_pack_schema( a );
pack_t schema_b = bli_obj_pack_schema( b );
dim_t m = bli_obj_length( c );
dim_t n = bli_obj_width( c );
dim_t k = bli_obj_width( a );
void* buf_a = bli_obj_buffer_at_off( a );
inc_t cs_a = bli_obj_col_stride( a );
inc_t is_a = bli_obj_imag_stride( a );
dim_t pd_a = bli_obj_panel_dim( a );
inc_t ps_a = bli_obj_panel_stride( a );
void* buf_b = bli_obj_buffer_at_off( b );
inc_t rs_b = bli_obj_row_stride( b );
inc_t is_b = bli_obj_imag_stride( b );
dim_t pd_b = bli_obj_panel_dim( b );
inc_t ps_b = bli_obj_panel_stride( b );
void* buf_c = bli_obj_buffer_at_off( c );
inc_t rs_c = bli_obj_row_stride( c );
inc_t cs_c = bli_obj_col_stride( c );
obj_t scalar_a;
obj_t scalar_b;
void* buf_alpha;
void* buf_beta;
FUNCPTR_T f;
// Detach and multiply the scalars attached to A and B.
// NOTE: We know that the internal scalars of A and B are already of the
// target datatypes because the necessary typecasting would have already
// taken place during bli_packm_init().
bli_obj_scalar_detach( a, &scalar_a );
bli_obj_scalar_detach( b, &scalar_b );
bli_mulsc( &scalar_a, &scalar_b );
// Grab the addresses of the internal scalar buffers for the scalar
// merged above and the scalar attached to C.
// NOTE: We know that scalar_b is of type dt_exec due to the above code
// that casts the scalars of A and B to dt_exec via scalar_a and scalar_b,
// and we know that the internal scalar in C is already of the type dt_c
// due to the casting in the implementation of bli_obj_scalar_attach().
buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b );
buf_beta = bli_obj_internal_scalar_buffer( c );
#if 0
// NOTE: Turns out that this optimization will never be employed since
// currently bli_gemm_ker_var2_md() is only called when the storage
// datatype of C differs from the execution/computation datatype, and
// this optimization would only make sense if they are equal.
// If 1m is being employed on a column- or row-stored matrix with a
// real-valued beta, we can use the real domain macro-kernel, which
// eliminates a little overhead associated with the 1m virtual
// micro-kernel.
if ( bli_cntx_method( cntx ) == BLIS_1M )
{
// Only employ this optimization if the storage datatype of C is
// equal to the execution/computation datatype.
if ( dt_c == dt_exec )
{
bli_gemm_ind_recast_1m_params
(
&dt_exec,
schema_a,
c,
&m, &n, &k,
&pd_a, &ps_a,
&pd_b, &ps_b,
&rs_c, &cs_c
);
}
}
#endif
// Tweak parameters in select mixed domain cases (rcc, crc, ccr).
bli_gemm_md_ker_var2_recast
(
&dt_exec,
bli_obj_dt( a ),
bli_obj_dt( b ),
bli_obj_dt( c ),
&m, &n, &k,
&pd_a, &ps_a,
&pd_b, &ps_b,
c,
&rs_c, &cs_c
);
// Index into the type combination array to extract the correct
// function pointer.
f = ftypes[dt_c][dt_exec];
// Invoke the function.
f( schema_a,
schema_b,
m,
n,
k,
buf_alpha,
buf_a, cs_a, is_a,
pd_a, ps_a,
buf_b, rs_b, is_b,
pd_b, ps_b,
buf_beta,
buf_c, rs_c, cs_c,
cntx,
rntm,
thread );
}
#undef GENTFUNC2
#define GENTFUNC2( ctype_c, ctype_e, chc, che, varname ) \
\
void PASTEMAC2(chc,che,varname) \
( \
pack_t schema_a, \
pack_t schema_b, \
dim_t m, \
dim_t n, \
dim_t k, \
void* alpha, \
void* a, inc_t cs_a, inc_t is_a, \
dim_t pd_a, inc_t ps_a, \
void* b, inc_t rs_b, inc_t is_b, \
dim_t pd_b, inc_t ps_b, \
void* beta, \
void* c, inc_t rs_c, inc_t cs_c, \
cntx_t* cntx, \
rntm_t* rntm, \
thrinfo_t* thread \
) \
{ \
const num_t dte = PASTEMAC(che,type); \
/*const num_t dtc = PASTEMAC(chc,type);*/ \
\
/* Alias some constants to simpler names. */ \
const dim_t MR = pd_a; \
const dim_t NR = pd_b; \
/*const dim_t PACKMR = cs_a;*/ \
/*const dim_t PACKNR = rs_b;*/ \
\
/* Query the context for the micro-kernel address and cast it to its
function pointer type. */ \
PASTECH(che,gemm_ukr_ft) \
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dte, BLIS_GEMM_UKR, cntx ); \
\
/* Temporary C buffer for edge cases. Note that the strides of this
temporary buffer are set so that they match the storage of the
original C matrix. For example, if C is column-stored, ct will be
column-stored as well. */ \
ctype_e ct[ BLIS_STACK_BUF_MAX_SIZE \
/ sizeof( ctype_e ) ] \
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dte, BLIS_GEMM_UKR, cntx ); \
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
\
ctype_e* restrict zero = PASTEMAC(che,0); \
ctype_e* restrict a_cast = a; \
ctype_e* restrict b_cast = b; \
ctype_c* restrict c_cast = c; \
ctype_e* restrict alpha_cast = alpha; \
ctype_c* restrict beta_cast = beta; \
ctype_e* restrict b1; \
ctype_c* restrict c1; \
\
dim_t m_iter, m_left; \
dim_t n_iter, n_left; \
dim_t i, j; \
dim_t m_cur; \
dim_t n_cur; \
inc_t rstep_a; \
inc_t cstep_b; \
inc_t rstep_c, cstep_c; \
auxinfo_t aux; \
\
/*
Assumptions/assertions:
rs_a == 1
cs_a == PACKMR
pd_a == MR
ps_a == stride to next micro-panel of A
rs_b == PACKNR
cs_b == 1
pd_b == NR
ps_b == stride to next micro-panel of B
rs_c == (no assumptions)
cs_c == (no assumptions)
*/ \
\
/* If any dimension is zero, return immediately. */ \
if ( bli_zero_dim3( m, n, k ) ) return; \
\
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
PASTEMAC(che,set0s_mxn)( MR, NR, \
ct, rs_ct, cs_ct ); \
\
/* Compute number of primary and leftover components of the m and n
dimensions. */ \
n_iter = n / NR; \
n_left = n % NR; \
\
m_iter = m / MR; \
m_left = m % MR; \
\
if ( n_left ) ++n_iter; \
if ( m_left ) ++m_iter; \
\
/* Determine some increments used to step through A, B, and C. */ \
rstep_a = ps_a; \
\
cstep_b = ps_b; \
\
rstep_c = rs_c * MR; \
cstep_c = cs_c * NR; \
\
/* Save the pack schemas of A and B to the auxinfo_t object. */ \
bli_auxinfo_set_schema_a( schema_a, &aux ); \
bli_auxinfo_set_schema_b( schema_b, &aux ); \
\
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
bli_auxinfo_set_is_a( is_a, &aux ); \
bli_auxinfo_set_is_b( is_b, &aux ); \
\
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
loop around the microkernel. Here we query the thrinfo_t node for the
1st (ir) loop around the microkernel. */ \
thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \
\
/* Query the number of threads and thread ids for each loop. */ \
dim_t jr_nt = bli_thread_n_way( thread ); \
dim_t jr_tid = bli_thread_work_id( thread ); \
dim_t ir_nt = bli_thread_n_way( caucus ); \
dim_t ir_tid = bli_thread_work_id( caucus ); \
\
dim_t jr_start, jr_end; \
dim_t ir_start, ir_end; \
dim_t jr_inc, ir_inc; \
\
/* Determine the thread range and increment for the 2nd and 1st loops.
NOTE: The definition of bli_thread_range_jrir() will depend on whether
slab or round-robin partitioning was requested at configure-time. */ \
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
\
/* Loop over the n dimension (NR columns at a time). */ \
for ( j = jr_start; j < jr_end; j += jr_inc ) \
{ \
ctype_e* restrict a1; \
ctype_c* restrict c11; \
ctype_e* restrict b2; \
\
b1 = b_cast + j * cstep_b; \
c1 = c_cast + j * cstep_c; \
\
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
\
/* Initialize our next panel of B to be the current panel of B. */ \
b2 = b1; \
\
/* Loop over the m dimension (MR rows at a time). */ \
for ( i = ir_start; i < ir_end; i += ir_inc ) \
{ \
ctype_e* restrict a2; \
\
a1 = a_cast + i * rstep_a; \
c11 = c1 + i * rstep_c; \
\
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
\
/* Compute the addresses of the next panels of A and B. */ \
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
{ \
a2 = a_cast; \
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
b2 = b_cast; \
} \
\
/* Save addresses of next panels of A and B to the auxinfo_t
object. */ \
bli_auxinfo_set_next_a( a2, &aux ); \
bli_auxinfo_set_next_b( b2, &aux ); \
\
/* Always save the micropanel product to the local microtile and
then accumulate it into C via the xpbys_mxn macro. */ \
/*if ( 1 )*/ \
{ \
/* Invoke the gemm micro-kernel. */ \
gemm_ukr \
( \
k, \
alpha_cast, \
a1, \
b1, \
zero, \
ct, rs_ct, cs_ct, \
&aux, \
cntx \
); \
\
/* Scale the microtile of C and add the result from above. */ \
PASTEMAC3(che,chc,chc,xpbys_mxn) \
( \
m_cur, n_cur, \
ct, rs_ct, cs_ct, \
beta_cast, \
c11, rs_c, cs_c \
); \
} \
} \
} \
\
/*
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \
*/ \
}
INSERT_GENTFUNC2_BASIC0( gemm_ker_var2_md )
INSERT_GENTFUNC2_MIXDP0( gemm_ker_var2_md )
#endif


@@ -154,7 +154,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
       num_t* dt_comp,
       num_t  dt_a,
       num_t  dt_b,
-      num_t  dt_c,
+      num_t* dt_c,
       dim_t* m,
       dim_t* n,
       dim_t* k,
@@ -164,7 +164,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
       inc_t* rs_c, inc_t* cs_c
     )
 {
-    if ( bli_is_real( dt_c )    &&
+    if ( bli_is_real( *dt_c )   &&
          bli_is_complex( dt_a ) &&
         bli_is_complex( dt_b ) )
     {
@@ -177,7 +177,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
         *ps_a *= 2;
         *ps_b *= 2;
     }
-    else if ( bli_is_complex( dt_c ) &&
+    else if ( bli_is_complex( *dt_c ) &&
               bli_is_real( dt_a )    &&
              bli_is_complex( dt_b ) )
     {
@@ -197,6 +197,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
         // to the real virtual microkernel slots of the context) instead of
         // the complex macrokernel and c2r virtual microkernel.
         *dt_comp = bli_dt_proj_to_real( *dt_comp );
+        *dt_c    = bli_dt_proj_to_real( *dt_c );
         *n *= 2;
         *pd_b *= 2; *ps_b *= 2;
         *rs_c *= 2;
@@ -211,7 +212,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
             *ps_a /= 2;
         }
     }
-    else if ( bli_is_complex( dt_c ) &&
+    else if ( bli_is_complex( *dt_c ) &&
              bli_is_complex( dt_a ) &&
              bli_is_real( dt_b )    )
     {
@@ -231,6 +232,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
         // to the real virtual microkernel slots of the context) instead of
         // the complex macrokernel and c2r virtual microkernel.
         *dt_comp = bli_dt_proj_to_real( *dt_comp );
+        *dt_c    = bli_dt_proj_to_real( *dt_c );
         *m *= 2;
         *pd_a *= 2; *ps_a *= 2;
         *cs_c *= 2;
@@ -274,54 +276,3 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
 #endif
 }
-
-// -----------------------------------------------------------------------------
-
-//
-// Prototype object-based interfaces.
-//
-
-#undef  GENPROT
-#define GENPROT( opname ) \
-\
-void PASTEMAC0(opname) \
-     ( \
-       obj_t*     a, \
-       obj_t*     b, \
-       obj_t*     c, \
-       cntx_t*    cntx, \
-       rntm_t*    rntm, \
-       cntl_t*    cntl, \
-       thrinfo_t* thread \
-     );
-
-GENPROT( gemm_ker_var2_md )
-
-//
-// Prototype BLAS-like interfaces with void pointer operands.
-//
-
-#undef  GENTPROT2
-#define GENTPROT2( ctype_c, ctype_e, chc, che, varname ) \
-\
-void PASTEMAC2(chc,che,varname) \
-     ( \
-       pack_t  schema_a, \
-       pack_t  schema_b, \
-       dim_t   m, \
-       dim_t   n, \
-       dim_t   k, \
-       void*   alpha, \
-       void*   a, inc_t cs_a, inc_t is_a, \
-       dim_t   pd_a, inc_t ps_a, \
-       void*   b, inc_t rs_b, inc_t is_b, \
-       dim_t   pd_b, inc_t ps_b, \
-       void*   beta, \
-       void*   c, inc_t rs_c, inc_t cs_c, \
-       cntx_t* cntx, \
-       rntm_t* rntm, \
-       thrinfo_t* thread \
-     );
-
-INSERT_GENTPROT2_BASIC0( gemm_ker_var2_md )
-INSERT_GENTPROT2_MIXDP0( gemm_ker_var2_md )


@@ -41,6 +41,8 @@
 \
 void PASTEMAC2(ch,opname,suf) \
      ( \
+       dim_t           m, \
+       dim_t           n, \
        dim_t           k, \
        ctype* restrict alpha, \
        ctype* restrict a, \
@@ -61,6 +63,9 @@ void PASTEMAC2(ch,opname,suf) \
 \
     const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
     const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
+\
+    dim_t mr_r = mr; \
+    dim_t nr_r = nr; \
 \
     ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
               / sizeof( ctype_r ) ] \
@@ -81,6 +86,9 @@ void PASTEMAC2(ch,opname,suf) \
 \
     ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \
     ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \
+\
+    dim_t m_use; \
+    dim_t n_use; \
 \
     ctype_r* c_use; \
     inc_t    rs_c_use; \
@@ -146,17 +154,16 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
         rs_c_use = rs_ct; \
         cs_c_use = cs_ct; \
 \
-    /* Convert the strides from being in units of complex elements to
-       be in units of real elements. Note that we don't need to check for
-       general storage here because that case corresponds to the scenario
-       where we are using the ct buffer and its rs_ct/cs_ct strides. */ \
-    if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \
-    else                                           rs_c_use *= 2; \
-\
+    /* Convert the strides and corresponding microtile dimension from being
+       in units of complex elements to be in units of real elements. */ \
+    if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \
+    else                                           { rs_c_use *= 2; nr_r *= 2; }\
 \
     /* c = beta * c + alpha_r * a * b; */ \
     rgemm_ukr \
     ( \
+      mr_r, \
+      nr_r, \
       k, \
       alpha_r, \
      a_r, \
@@ -166,14 +173,12 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
       data, \
       cntx \
     ); \
-\
-    dim_t i, j; \
 \
     /* Accumulate the final result in ct back to c. */ \
     if ( PASTEMAC(ch,eq1)( *beta ) ) \
     { \
-        for ( j = 0; j < nr; ++j ) \
-        for ( i = 0; i < mr; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
         { \
            PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
                                *(c  + i*rs_c  + j*cs_c ) ); \
@@ -181,8 +186,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
     } \
     else if ( PASTEMAC(ch,eq0)( *beta ) ) \
     { \
-        for ( j = 0; j < nr; ++j ) \
-        for ( i = 0; i < mr; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
         { \
             PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
                                 *(c  + i*rs_c  + j*cs_c ) ); \
@@ -190,8 +195,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
     } \
     else \
     { \
-        for ( j = 0; j < nr; ++j ) \
-        for ( i = 0; i < mr; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
         { \
             PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
                                 *beta, \
@@ -207,17 +212,19 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
         c_use    = ( ctype_r* )c; \
         rs_c_use = rs_c; \
         cs_c_use = cs_c; \
+        m_use    = m; \
+        n_use    = n; \
 \
-    /* Convert the strides from being in units of complex elements to
-       be in units of real elements. Note that we don't need to check for
-       general storage here because that case corresponds to the scenario
-       where we are using the ct buffer and its rs_ct/cs_ct strides. */ \
-    if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \
-    else                                           rs_c_use *= 2; \
+    /* Convert the strides and corresponding microtile dimension from being
+       in units of complex elements to be in units of real elements. */ \
+    if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \
+    else                                           { rs_c_use *= 2; n_use *= 2; } \
 \
     /* c = beta * c + alpha_r * a * b; */ \
     rgemm_ukr \
     ( \
+      m_use, \
+      n_use, \
       k, \
       alpha_r, \
       a_r, \


@@ -34,6 +34,16 @@
 */
 
+//
+// gemm kernel parameter struct.
+//
+
+typedef struct
+{
+    gemm_ukr_vft ukr;
+} gemm_ker_params_t;
+
 //
 // Prototype object-based interfaces.
 //
@@ -59,32 +69,3 @@ GENPROT( gemm_blk_var3 )
 GENPROT( gemm_ker_var1 )
 GENPROT( gemm_ker_var2 )
-
-//
-// Prototype BLAS-like interfaces with void pointer operands.
-//
-
-#undef  GENTPROT
-#define GENTPROT( ctype, ch, varname ) \
-\
-void PASTEMAC(ch,varname) \
-     ( \
-       pack_t  schema_a, \
-       pack_t  schema_b, \
-       dim_t   m, \
-       dim_t   n, \
-       dim_t   k, \
-       void*   alpha, \
-       void*   a, inc_t cs_a, inc_t is_a, \
-       dim_t   pd_a, inc_t ps_a, \
-       void*   b, inc_t rs_b, inc_t is_b, \
-       dim_t   pd_b, inc_t ps_b, \
-       void*   beta, \
-       void*   c, inc_t rs_c, inc_t cs_c, \
-       cntx_t* cntx, \
-       rntm_t* rntm, \
-       thrinfo_t* thread \
-     );
-
-INSERT_GENTPROT_BASIC0( gemm_ker_var2 )
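
The gemm_ker_params_t struct above is the user-facing hook that the
reimplemented bli_gemm_ker_var2() queries via bli_obj_ker_params( c ). A
hedged usage sketch follows (not from this commit): the setter names
bli_obj_set_ker_fn() and bli_obj_set_ker_params() are assumed here to mirror
the accessors bli_obj_ker_fn()/bli_obj_ker_params() that appear in this diff,
and my_dgemm_ukr is the sketch kernel shown earlier, cast to the void-typed
virtual microkernel pointer type that the .ukr slot holds.

#include "blis.h"

int main( void )
{
    dim_t m = 1000, n = 800, k = 600;

    obj_t a, b, c;
    bli_obj_create( BLIS_DOUBLE, m, k, 0, 0, &a );
    bli_obj_create( BLIS_DOUBLE, k, n, 0, 0, &b );
    bli_obj_create( BLIS_DOUBLE, m, n, 0, 0, &c );
    bli_randm( &a ); bli_randm( &b ); bli_setm( &BLIS_ZERO, &c );

    /* Route gemm through a custom microkernel. The macrokernel finds this
       struct via bli_obj_ker_params( &c ) and forwards both the ukr pointer
       and the struct itself through the auxinfo_t. */
    gemm_ker_params_t params = { .ukr = ( gemm_ukr_vft )my_dgemm_ukr };
    bli_obj_set_ker_params( &params, &c );

    /* A custom macrokernel could be installed analogously; when left NULL,
       the default (e.g. bli_gemm_ker_var2) is used. */
    /* bli_obj_set_ker_fn( my_macrokernel, &c ); */

    bli_gemm( &BLIS_ONE, &a, &b, &BLIS_ONE, &c );

    bli_obj_free( &a ); bli_obj_free( &b ); bli_obj_free( &c );
    return 0;
}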


@@ -35,6 +35,7 @@
 BLIS_INLINE void bli_gemm_ind_recast_1m_params
      (
        num_t* dt_exec,
+       num_t* dt_c,
        pack_t schema_a,
        obj_t* c,
        dim_t* m,
@@ -57,6 +58,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params
          !bli_is_gen_stored( *rs_c, *cs_c ) )
     {
         *dt_exec = bli_dt_proj_to_real( *dt_exec );
+        *dt_c    = bli_dt_proj_to_real( *dt_c );
 
         if ( bli_is_1e_packed( schema_a ) )
         {


@@ -279,6 +279,9 @@ void PASTEMAC(ch,varname) \
 	/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
 	bli_auxinfo_set_is_a( is_a, &aux ); \
 	bli_auxinfo_set_is_b( is_b, &aux ); \
+\
+	/* Save the desired output datatype (indicating no typecasting). */ \
+	/*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \
 \
 	/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
 	   loop around the microkernel. Here we query the thrinfo_t node for the
@@ -381,43 +384,20 @@ void PASTEMAC(ch,varname) \
 			   And if we're strictly above the diagonal, we do nothing and
 			   continue. */ \
 			{ \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Scale the edge of C and add the result. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        beta_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 		} \
 	} \
@@ -490,6 +470,8 @@ void PASTEMAC(ch,varname) \
 				/* Invoke the gemm micro-kernel. */ \
 				gemm_ukr \
 				( \
+				  MR, \
+				  NR, \
 				  k, \
 				  alpha_cast, \
 				  a1, \
@@ -509,43 +491,20 @@ void PASTEMAC(ch,varname) \
 			} \
 			else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
 			{ \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Scale the edge of C and add the result. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        beta_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 		} \
 	} \
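The PASTEMAC(ch,xpbys_mxn) call removed here (and the flush step of the edge-case macros in bli_edge_case_macro_defs.h, later in this commit, which subsumes it) implements c := beta*c + x over an m_cur x n_cur subtile. A freestanding reference sketch of that semantic, with an illustrative name:

    #include <stdio.h>

    /* Semantics of the xpbys_mxn family: c := beta*c + x over an m x n
       tile, where x and c each carry independent row/column strides. */
    static void dxpbys_mxn_ref( int m, int n,
                                const double* x, int rs_x, int cs_x,
                                double beta,
                                double* c, int rs_c, int cs_c )
    {
        for ( int j = 0; j < n; ++j )
            for ( int i = 0; i < m; ++i )
                c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ]
                                     + x[ i*rs_x + j*cs_x ];
    }

    int main( void )
    {
        double x[ 4 ] = {  1,  2,  3,  4 };  /* 2x2, column-stored */
        double c[ 4 ] = { 10, 20, 30, 40 };
        dxpbys_mxn_ref( 2, 2, x, 1, 2, 0.5, c, 1, 2 );
        printf( "%g %g %g %g\n", c[0], c[1], c[2], c[3] );  /* 6 12 18 24 */
        return 0;
    }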


@@ -281,6 +281,9 @@ void PASTEMAC(ch,varname) \
 	/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
 	bli_auxinfo_set_is_a( is_a, &aux ); \
 	bli_auxinfo_set_is_b( is_b, &aux ); \
+\
+	/* Save the desired output datatype (indicating no typecasting). */ \
+	/*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \
 \
 	/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
 	   loop around the microkernel. Here we query the thrinfo_t node for the
@@ -385,6 +388,8 @@ void PASTEMAC(ch,varname) \
 				/* Invoke the gemm micro-kernel. */ \
 				gemm_ukr \
 				( \
+				  MR, \
+				  NR, \
 				  k, \
 				  alpha_cast, \
 				  a1, \
@@ -404,43 +409,20 @@ void PASTEMAC(ch,varname) \
 			} \
 			else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
 			{ \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Scale the edge of C and add the result. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        beta_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 		} \
 	} \
@@ -512,43 +494,20 @@ void PASTEMAC(ch,varname) \
 			   And if we're strictly below the diagonal, we do nothing and
 			   continue. */ \
 			{ \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Scale the edge of C and add the result. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        beta_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 		} \
 	} \


@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
 	   function pointer type. */ \
 	PASTECH(ch,gemm_ukr_ft) \
 	               gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
-\
-	/* Temporary C buffer for edge cases. Note that the strides of this
-	   temporary buffer are set so that they match the storage of the
-	   original C matrix. For example, if C is column-stored, ct will be
-	   column-stored as well. */ \
-	ctype           ct[ BLIS_STACK_BUF_MAX_SIZE \
-	                    / sizeof( ctype ) ] \
-	                    __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
-	const bool      col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
-	const inc_t     rs_ct    = ( col_pref ? 1 : NR ); \
-	const inc_t     cs_ct    = ( col_pref ? MR : 1 ); \
 \
 	ctype* restrict one    = PASTEMAC(ch,1); \
-	ctype* restrict zero   = PASTEMAC(ch,0); \
 	ctype* restrict a_cast = a; \
 	ctype* restrict b_cast = b; \
 	ctype* restrict c_cast = c; \
@@ -254,10 +242,6 @@ void PASTEMAC(ch,varname) \
 		diagoffa = 0; \
 		c_cast   = c_cast + (i  )*rs_c; \
 	} \
-\
-	/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
-	PASTEMAC(ch,set0s_mxn)( MR, NR, \
-	                        ct, rs_ct, cs_ct ); \
 \
 	/* Compute number of primary and leftover components of the m and n
 	   dimensions. */ \
@@ -307,8 +291,8 @@ void PASTEMAC(ch,varname) \
 	dim_t jr_inc; \
 \
 	/* Determine the thread range and increment for the 2nd loop.
 	   NOTE: The definition of bli_thread_range_jrir() will depend on whether
 	   slab or round-robin partitioning was requested at configure-time. \
 	   NOTE: Parallelism in the 1st loop is disabled for now. */ \
 	bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
 	/*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \
@@ -379,47 +363,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_a1011, \
-					  alpha_cast, \
-					  a1, \
-					  b1_i, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Copy edge elements of C to the temporary buffer. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        c11, rs_c, cs_c, \
-					                        ct, rs_ct, cs_ct ); \
-\
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_a1011, \
-					  alpha_cast, \
-					  a1, \
-					  b1_i, \
-					  beta_cast, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Copy the result to the edge of C. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k_a1011, \
+				  alpha_cast, \
+				  a1, \
+				  b1_i, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 				/*}*/ \
 \
 				a1 += ps_a_cur; \
@@ -446,42 +403,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  one, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
-					                       ct, rs_ct, cs_ct, \
-					                       c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  one, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 				/*}*/ \
 \
 				a1 += rstep_a; \


@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
 	   function pointer type. */ \
 	PASTECH(ch,gemm_ukr_ft) \
 	               gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
-\
-	/* Temporary C buffer for edge cases. Note that the strides of this
-	   temporary buffer are set so that they match the storage of the
-	   original C matrix. For example, if C is column-stored, ct will be
-	   column-stored as well. */ \
-	ctype           ct[ BLIS_STACK_BUF_MAX_SIZE \
-	                    / sizeof( ctype ) ] \
-	                    __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
-	const bool      col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
-	const inc_t     rs_ct    = ( col_pref ? 1 : NR ); \
-	const inc_t     cs_ct    = ( col_pref ? MR : 1 ); \
 \
 	ctype* restrict one    = PASTEMAC(ch,1); \
-	ctype* restrict zero   = PASTEMAC(ch,0); \
 	ctype* restrict a_cast = a; \
 	ctype* restrict b_cast = b; \
 	ctype* restrict c_cast = c; \
@@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \
 	{ \
 		m = -diagoffa + k; \
 	} \
-\
-	/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
-	PASTEMAC(ch,set0s_mxn)( MR, NR, \
-	                        ct, rs_ct, cs_ct ); \
 \
 	/* Compute number of primary and leftover components of the m and n
 	   dimensions. */ \
@@ -386,47 +370,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_a1112, \
-					  alpha_cast, \
-					  a1, \
-					  b1_i, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Copy edge elements of C to the temporary buffer. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        c11, rs_c, cs_c, \
-					                        ct, rs_ct, cs_ct ); \
-\
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_a1112, \
-					  alpha_cast, \
-					  a1, \
-					  b1_i, \
-					  beta_cast, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Copy the result to the edge of C. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k_a1112, \
+				  alpha_cast, \
+				  a1, \
+				  b1_i, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 				/*}*/ \
 \
 				a1 += ps_a_cur; \
@@ -453,42 +410,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  one, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
-					                       ct, rs_ct, cs_ct, \
-					                       c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  one, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 				/*}*/ \
 \
 				a1 += rstep_a; \


@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
 	   function pointer type. */ \
 	PASTECH(ch,gemm_ukr_ft) \
 	               gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
-\
-	/* Temporary C buffer for edge cases. Note that the strides of this
-	   temporary buffer are set so that they match the storage of the
-	   original C matrix. For example, if C is column-stored, ct will be
-	   column-stored as well. */ \
-	ctype           ct[ BLIS_STACK_BUF_MAX_SIZE \
-	                    / sizeof( ctype ) ] \
-	                    __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
-	const bool      col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
-	const inc_t     rs_ct    = ( col_pref ? 1 : NR ); \
-	const inc_t     cs_ct    = ( col_pref ? MR : 1 ); \
 \
 	ctype* restrict one    = PASTEMAC(ch,1); \
-	ctype* restrict zero   = PASTEMAC(ch,0); \
 	ctype* restrict a_cast = a; \
 	ctype* restrict b_cast = b; \
 	ctype* restrict c_cast = c; \
@@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \
 	{ \
 		n = diagoffb + k; \
 	} \
-\
-	/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
-	PASTEMAC(ch,set0s_mxn)( MR, NR, \
-	                        ct, rs_ct, cs_ct ); \
 \
 	/* Compute number of primary and leftover components of the m and n
 	   dimensions. */ \
@@ -335,9 +319,9 @@ void PASTEMAC(ch,varname) \
 \
 	/* Determine the thread range and increment for the 2nd and 1st loops for
 	   the initial rectangular region of B (if it exists).
 	   NOTE: The definition of bli_thread_range_jrir() will depend on whether
 	   slab or round-robin partitioning was requested at configure-time. \
 	   NOTE: Parallelism in the 1st loop is disabled for now. */ \
 	bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
 	bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
 \
@@ -382,42 +366,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  one, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
-					                       ct, rs_ct, cs_ct, \
-					                       c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  one, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 		} \
 	} \
@@ -501,47 +463,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_b1121, \
-					  alpha_cast, \
-					  a1_i, \
-					  b1, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Copy edge elements of C to the temporary buffer. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        c11, rs_c, cs_c, \
-					                        ct, rs_ct, cs_ct ); \
-\
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_b1121, \
-					  alpha_cast, \
-					  a1_i, \
-					  b1, \
-					  beta_cast, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Copy the result to the edge of C. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k_b1121, \
+				  alpha_cast, \
+				  a1_i, \
+				  b1, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 \
 			a1 += rstep_a; \


@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
 	   function pointer type. */ \
 	PASTECH(ch,gemm_ukr_ft) \
 	               gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
-\
-	/* Temporary C buffer for edge cases. Note that the strides of this
-	   temporary buffer are set so that they match the storage of the
-	   original C matrix. For example, if C is column-stored, ct will be
-	   column-stored as well. */ \
-	ctype           ct[ BLIS_STACK_BUF_MAX_SIZE \
-	                    / sizeof( ctype ) ] \
-	                    __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
-	const bool      col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
-	const inc_t     rs_ct    = ( col_pref ? 1 : NR ); \
-	const inc_t     cs_ct    = ( col_pref ? MR : 1 ); \
 \
 	ctype* restrict one    = PASTEMAC(ch,1); \
-	ctype* restrict zero   = PASTEMAC(ch,0); \
 	ctype* restrict a_cast = a; \
 	ctype* restrict b_cast = b; \
 	ctype* restrict c_cast = c; \
@@ -262,10 +250,6 @@ void PASTEMAC(ch,varname) \
 	{ \
 		k = -diagoffb + n; \
 	} \
-\
-	/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
-	PASTEMAC(ch,set0s_mxn)( MR, NR, \
-	                        ct, rs_ct, cs_ct ); \
 \
 	/* Compute number of primary and leftover components of the m and n
 	   dimensions. */ \
@@ -410,47 +394,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_b0111, \
-					  alpha_cast, \
-					  a1_i, \
-					  b1, \
-					  beta_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Copy edge elements of C to the temporary buffer. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        c11, rs_c, cs_c, \
-					                        ct, rs_ct, cs_ct ); \
-\
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k_b0111, \
-					  alpha_cast, \
-					  a1_i, \
-					  b1, \
-					  beta_cast, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Copy the result to the edge of C. */ \
-					PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k_b0111, \
+				  alpha_cast, \
+				  a1_i, \
+				  b1, \
+				  beta_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 \
 			a1 += rstep_a; \
@@ -476,9 +433,9 @@ void PASTEMAC(ch,varname) \
 	bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
 \
 	/* Advance the start and end iteration offsets for the rectangular region
 	   by the number of iterations used for the triangular region. */ \
 	jr_start += n_iter_tri; \
 	jr_end   += n_iter_tri; \
 	jb0       = n_iter_tri; \
 \
 	/* Save the resulting value of b1 from the previous loop since it represents
@@ -496,7 +453,7 @@ void PASTEMAC(ch,varname) \
 	   the starting address of the rectangular region (which is already
 	   n_iter_tri logical iterations through B). */ \
 		b1 = b_cast + (j-jb0) * cstep_b; \
 		c1 = c_cast + (j    ) * cstep_c; \
 \
 		n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
 \
@@ -533,42 +490,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  one, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  alpha_cast, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
-					                       ct, rs_ct, cs_ct, \
-					                       c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  alpha_cast, \
+				  a1, \
+				  b1, \
+				  one, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 		} \
 	} \


@@ -40,27 +40,30 @@ cntl_t* bli_trsm_cntl_create
        rntm_t* rntm,
        side_t  side,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      )
 {
 	if ( bli_is_left( side ) )
-		return bli_trsm_l_cntl_create( rntm, schema_a, schema_b );
+		return bli_trsm_l_cntl_create( rntm, schema_a, schema_b, ker );
 	else
-		return bli_trsm_r_cntl_create( rntm, schema_a, schema_b );
+		return bli_trsm_r_cntl_create( rntm, schema_a, schema_b, ker );
 }
 
 cntl_t* bli_trsm_l_cntl_create
      (
        rntm_t* rntm,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      )
 {
 	void_fp macro_kernel_p;
 
-	// Use the function pointer to the macrokernels that use slab
-	// assignment of micropanels to threads in the jr and ir loops.
+	// Set the default macrokernel. If a non-NULL kernel function pointer is
+	// passed in, we use that instead.
 	macro_kernel_p = bli_trsm_xx_ker_var2;
+	if ( ker ) macro_kernel_p = ker;
 
 	const opid_t family = BLIS_TRSM;
@@ -202,11 +205,15 @@ cntl_t* bli_trsm_r_cntl_create
      (
        rntm_t* rntm,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      )
 {
 	// NOTE: trsm macrokernels are presently disabled for right-side execution.
+	// Set the default macrokernel. If a non-NULL kernel function pointer is
+	// passed in, we use that instead.
 	void_fp macro_kernel_p = bli_trsm_xx_ker_var2;
+	if ( ker ) macro_kernel_p = ker;
 
 	const opid_t family = BLIS_TRSM;
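The if ( ker ) fallback is the entire dispatch contract here: a NULL pointer selects the stock macrokernel and anything non-NULL overrides it. A minimal self-contained sketch of the same pattern (names illustrative):

    #include <stdio.h>

    typedef void (*void_fp)( void );

    static void default_kernel( void ) { puts( "default macrokernel" ); }
    static void user_kernel   ( void ) { puts( "user macrokernel" );    }

    /* NULL selects the default; a non-NULL pointer overrides it. */
    static void_fp select_kernel( void_fp ker )
    {
        void_fp macro_kernel_p = default_kernel;
        if ( ker ) macro_kernel_p = ker;
        return macro_kernel_p;
    }

    int main( void )
    {
        select_kernel( NULL        )();  /* prints: default macrokernel */
        select_kernel( user_kernel )();  /* prints: user macrokernel    */
        return 0;
    }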


@@ -38,21 +38,24 @@ cntl_t* bli_trsm_cntl_create
        rntm_t* rntm,
        side_t  side,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      );
 
 cntl_t* bli_trsm_l_cntl_create
      (
        rntm_t* rntm,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      );
 
 cntl_t* bli_trsm_r_cntl_create
      (
        rntm_t* rntm,
        pack_t  schema_a,
-       pack_t  schema_b
+       pack_t  schema_b,
+       void_fp ker
      );
 
 void bli_trsm_cntl_free


@@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \
 	const inc_t     rs_ct      = ( col_pref ? 1 : NR ); \
 	const inc_t     cs_ct      = ( col_pref ? MR : 1 ); \
 \
-	ctype* restrict zero       = PASTEMAC(ch,0); \
 	ctype* restrict minus_one  = PASTEMAC(ch,m1); \
 	ctype* restrict a_cast     = a; \
 	ctype* restrict b_cast     = b; \
@@ -470,43 +469,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  a1, \
-					  b1, \
-					  alpha2_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        alpha2_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  minus_one, \
+				  a1, \
+				  b1, \
+				  alpha2_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 \
 				a1 += rstep_a; \
 			} \


@@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \
 	const inc_t     rs_ct      = ( col_pref ? 1 : NR ); \
 	const inc_t     cs_ct      = ( col_pref ? MR : 1 ); \
 \
-	ctype* restrict zero       = PASTEMAC(ch,0); \
 	ctype* restrict minus_one  = PASTEMAC(ch,m1); \
 	ctype* restrict a_cast     = a; \
 	ctype* restrict b_cast     = b; \
@@ -480,43 +479,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( a2, &aux ); \
 				bli_auxinfo_set_next_b( b2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  a1, \
-					  b1, \
-					  alpha2_cast, \
-					  c11, rs_c, cs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  a1, \
-					  b1, \
-					  zero, \
-					  ct, rs_ct, cs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        alpha2_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. */ \
+				gemm_ukr \
+				( \
+				  m_cur, \
+				  n_cur, \
+				  k, \
+				  minus_one, \
+				  a1, \
+				  b1, \
+				  alpha2_cast, \
+				  c11, rs_c, cs_c, \
+				  &aux, \
+				  cntx \
+				); \
 \
 				a1 += rstep_a; \
 			} \


@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
 	const inc_t     rs_ct      = ( col_pref ? 1 : NR ); \
 	const inc_t     cs_ct      = ( col_pref ? MR : 1 ); \
 \
-	ctype* restrict zero       = PASTEMAC(ch,0); \
 	ctype* restrict minus_one  = PASTEMAC(ch,m1); \
 	ctype* restrict a_cast     = a; \
 	ctype* restrict b_cast     = b; \
@@ -499,43 +498,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( b2, &aux ); \
 				bli_auxinfo_set_next_b( a2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  alpha2_cast, \
-					  c11, cs_c, rs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  zero, \
-					  ct, cs_ct, rs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        alpha2_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. Note that the operands
+				   (and therefore the dimensions and strides) are swapped
+				   so that C is updated through its transposed view. */ \
+				gemm_ukr \
+				( \
+				  n_cur, \
+				  m_cur, \
+				  k, \
+				  minus_one, \
+				  b1, \
+				  a1, \
+				  alpha2_cast, \
+				  c11, cs_c, rs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 \
 			a1 += rstep_a; \


@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
 	const inc_t     rs_ct      = ( col_pref ? 1 : NR ); \
 	const inc_t     cs_ct      = ( col_pref ? MR : 1 ); \
 \
-	ctype* restrict zero       = PASTEMAC(ch,0); \
 	ctype* restrict minus_one  = PASTEMAC(ch,m1); \
 	ctype* restrict a_cast     = a; \
 	ctype* restrict b_cast     = b; \
@@ -492,43 +491,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( b2, &aux ); \
 				bli_auxinfo_set_next_b( a2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  alpha2_cast, \
-					  c11, cs_c, rs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  zero, \
-					  ct, cs_ct, rs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        alpha2_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. Note that the operands
+				   (and therefore the dimensions and strides) are swapped
+				   so that C is updated through its transposed view. */ \
+				gemm_ukr \
+				( \
+				  n_cur, \
+				  m_cur, \
+				  k, \
+				  minus_one, \
+				  b1, \
+				  a1, \
+				  alpha2_cast, \
+				  c11, cs_c, rs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 \
 			a1 += rstep_a; \


@@ -74,6 +74,15 @@ BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai )
 	return ai->ps_b;
 }
 
+BLIS_INLINE void_fp bli_auxinfo_ukr( auxinfo_t* ai )
+{
+	return ai->ukr;
+}
+
+BLIS_INLINE void* bli_auxinfo_params( auxinfo_t* ai )
+{
+	return ai->params;
+}
 
 // auxinfo_t field modification
@@ -118,5 +127,14 @@ BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai )
 	ai->ps_b = ps;
 }
 
+BLIS_INLINE void bli_auxinfo_set_ukr( void_fp ukr, auxinfo_t* ai )
+{
+	ai->ukr = ukr;
+}
+
+BLIS_INLINE void bli_auxinfo_set_params( void* params, auxinfo_t* ai )
+{
+	ai->params = params;
+}
 
 #endif
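A sketch of the round trip these accessors enable, from a kernel's perspective. The wrapper below is hypothetical; dgemm_ukr_ft is the double-precision instance of the gemm microkernel pointer type, which (as of this commit) takes m and n:

    #include "blis.h"

    /* Hypothetical wrapper microkernel: recovers the real microkernel and
       its params from the auxinfo_t it was handed, then dispatches. */
    void my_dgemm_ukr_wrapper
         (
           dim_t m, dim_t n, dim_t k,
           double* restrict alpha, double* restrict a, double* restrict b,
           double* restrict beta,  double* restrict c, inc_t rs_c, inc_t cs_c,
           auxinfo_t* restrict data, cntx_t* restrict cntx
         )
    {
        dgemm_ukr_ft ukr    = ( dgemm_ukr_ft )bli_auxinfo_ukr( data );
        void*        params = bli_auxinfo_params( data );

        ( void )params; /* a real wrapper would consult params first */
        ukr( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx );
    }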


@@ -0,0 +1,109 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2021, The University of Texas at Austin
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#ifndef BLIS_EDGE_CASE_MACRO_DEFS_H
+#define BLIS_EDGE_CASE_MACRO_DEFS_H
+
+// Helper macros for edge-case handling within gemm microkernels.
+
+#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \
+\
+	PASTEMAC(ch,ctype)* restrict _beta = beta; \
+	PASTEMAC(ch,ctype)* restrict _c    = c; \
+	const inc_t _rs_c = rs_c; \
+	const inc_t _cs_c = cs_c; \
+	PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,ctype) ) ] \
+	                   __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
+	const inc_t _rs_ct = row_major ? nr : 1; \
+	const inc_t _cs_ct = row_major ? 1  : mr;
+
+#define GEMM_UKR_SETUP_CT_POST(ch) \
+\
+	PASTEMAC(ch,ctype) _zero; \
+	PASTEMAC(ch,set0s)( _zero ); \
+\
+	if ( _use_ct ) \
+	{ \
+		c = _ct; \
+		rs_c = _rs_ct; \
+		cs_c = _cs_ct; \
+		beta = &_zero; \
+	}
+
+#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \
+\
+	GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+	const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
+	                     m != mr || n != nr; \
+	GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \
+\
+	GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+	const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \
+	                     m != mr || n != nr; \
+	GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \
+\
+	GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+	const bool _use_ct = m != mr || n != nr; \
+	GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \
+\
+	GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+	const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
+	                     m != mr || n != nr || \
+	                     ( (uintptr_t)_c % alignment ) || \
+	                     ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \
+	GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_FLUSH_CT(ch) \
+\
+	if ( _use_ct ) \
+	{ \
+		PASTEMAC(ch,xpbys_mxn) \
+		( \
+		  m, n, \
+		  _ct, _rs_ct, _cs_ct, \
+		  _beta, \
+		  _c, _rs_c, _cs_c \
+		); \
+	} \
+
+#endif
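Taken together, the SETUP/FLUSH pair lets a kernel body assume a full MR x NR tile with unit stride in its preferred dimension: partial tiles (and unsupported strides) are redirected into a zero-beta temporary, which the flush step scales by the caller's beta and accumulates back into the true edge of C. A deliberately naive sketch of a hypothetical 8x6 double-precision microkernel wrapped by the macros (reference loops stand in for real SIMD code):

    #include "blis.h"

    void my_dgemm_ukr_8x6
         (
           dim_t m, dim_t n, dim_t k,
           double* restrict alpha,
           double* restrict a,
           double* restrict b,
           double* restrict beta,
           double* restrict c, inc_t rs_c, inc_t cs_c,
           auxinfo_t* restrict data,
           cntx_t* restrict cntx
         )
    {
        ( void )data; ( void )cntx;

        /* After this, the body may assume a full 8x6 column-stored tile
           (rs_c == 1); otherwise c/beta were redirected to a temporary. */
        GEMM_UKR_SETUP_CT( d, 8, 6, false );

        for ( dim_t j = 0; j < 6; ++j )
            for ( dim_t i = 0; i < 8; ++i )
            {
                double ab = 0.0;
                for ( dim_t l = 0; l < k; ++l )
                    ab += a[ i + l*8 ] * b[ j + l*6 ];  /* packed panels */

                double* cij = c + i*rs_c + j*cs_c;
                /* beta == 0 must overwrite, not scale: the temporary
                   tile is uninitialized and may hold NaNs. */
                if ( *beta == 0.0 ) *cij = (*alpha) * ab;
                else                *cij = (*beta) * (*cij) + (*alpha) * ab;
            }

        /* Scale by the caller's beta and accumulate the temporary tile
           back into the true m x n edge of C, when one was used. */
        GEMM_UKR_FLUSH_CT( d );
    }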


@@ -98,6 +98,7 @@
 #include "bli_gentprot_macro_defs.h"
 
 #include "bli_misc_macro_defs.h"
+#include "bli_edge_case_macro_defs.h"
 #include "bli_param_macro_defs.h"
 #include "bli_obj_macro_defs.h"
 #include "bli_complex_macro_defs.h"


@@ -1144,6 +1144,13 @@ typedef struct
 	inc_t  ps_a;
 	inc_t  ps_b;
 
+	// The type to convert to on output.
+	//num_t  dt_on_output;
+
+	// (Virtual) microkernel address and additional parameters.
+	void_fp ukr;
+	void*   params;
+
 } auxinfo_t;


@@ -42,9 +42,13 @@
 // 2vx10 microkernels.
 #include "armsve_asm_2vx10cmplx.h"
 
+#include "arm_sve.h"
+
 void bli_cgemm_armsve_asm_2vx10_unindexed
      (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        scomplex*  restrict alpha,
        scomplex*  restrict a,
        scomplex*  restrict b,
@@ -59,12 +63,15 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint64_t k_mker = k0 / 4;
-	uint64_t k_left = k0 % 4;
+	uint64_t k_mker = k / 4;
+	uint64_t k_left = k % 4;
 	uint64_t rs_c   = rs_c0;
 	uint64_t cs_c   = cs_c0;
 	uint64_t info   = 0;
 
+	uint64_t mr     = svcntw();
+	GEMM_UKR_SETUP_CT( c, mr, 10, false );
+
 	__asm__ volatile (
 	// " ldr x0, %[a] \n\t"
 	// " ldr x1, %[b] \n\t"
@@ -310,5 +317,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
 	 "z24","z25","z26","z27",
 	 "z28","z29","z30","z31"
 	);
+
+	GEMM_UKR_FLUSH_CT( c );
 }
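A note on the mr computation above: these kernels are vector-length-agnostic, so MR must be queried at run time via the ACLE counting intrinsics before invoking the setup macro. A minimal sketch, assuming a compiler targeting SVE (e.g. -march=armv8-a+sve):

    #include <stdio.h>
    #include <arm_sve.h>  /* ACLE: svcntw()/svcntd() count 32-/64-bit lanes */

    int main( void )
    {
        /* The 2vx10 kernels span two vectors per column of the microtile,
           so MR is 2*svcntd() doubles (16 on a 512-bit SVE machine). */
        printf( "dgemm MR = %llu\n", (unsigned long long)( 2 * svcntd() ) );
        printf( "sgemm MR = %llu\n", (unsigned long long)( 2 * svcntw() ) );
        return 0;
    }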


@@ -42,9 +42,13 @@
 // 2vx10 microkernels.
 #include "armsve_asm_2vx10.h"
 
+#include "arm_sve.h"
+
 void bli_dgemm_armsve_asm_2vx10_unindexed
     (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        double*    restrict alpha,
        double*    restrict a,
        double*    restrict b,
@@ -59,11 +63,14 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint64_t k_mker = k0 / 4;
-	uint64_t k_left = k0 % 4;
+	uint64_t k_mker = k / 4;
+	uint64_t k_left = k % 4;
 	uint64_t rs_c   = rs_c0;
 	uint64_t cs_c   = cs_c0;
 
+	uint64_t mr     = 2*svcntd();
+	GEMM_UKR_SETUP_CT( d, mr, 10, false );
+
 	__asm__ volatile (
 	" ldr x0, %[a] \n\t"
 	" ldr x1, %[b] \n\t"
@@ -324,5 +331,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x
 	 "z24","z25","z26","z27",
 	 "z28","z29","z30","z31"
 	);
+
+	GEMM_UKR_FLUSH_CT( d );
 }


@@ -42,9 +42,13 @@
 // 2vx10 microkernels.
 #include "armsve_asm_2vx10.h"
 
+#include "arm_sve.h"
+
 void bli_sgemm_armsve_asm_2vx10_unindexed
     (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        float*     restrict alpha,
        float*     restrict a,
        float*     restrict b,
@@ -59,11 +63,14 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint64_t k_mker = k0 / 4;
-	uint64_t k_left = k0 % 4;
+	uint64_t k_mker = k / 4;
+	uint64_t k_left = k % 4;
 	uint64_t rs_c   = rs_c0;
 	uint64_t cs_c   = cs_c0;
 
+	uint64_t mr     = 2*svcntw();
+	GEMM_UKR_SETUP_CT( s, mr, 10, false );
+
 	__asm__ volatile (
 	" ldr x0, %[a] \n\t"
 	" ldr x1, %[b] \n\t"
@@ -310,5 +317,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x
 	 "z24","z25","z26","z27",
 	 "z28","z29","z30","z31"
 	);
+
+	GEMM_UKR_FLUSH_CT( s );
 }


@@ -42,9 +42,13 @@
 // 2vx10 microkernels.
 #include "armsve_asm_2vx10cmplx.h"
 
+#include "arm_sve.h"
+
 void bli_zgemm_armsve_asm_2vx10_unindexed
     (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        dcomplex*  restrict alpha,
        dcomplex*  restrict a,
        dcomplex*  restrict b,
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx10_unindexed
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint64_t k_mker = k0 / 4;
-	uint64_t k_left = k0 % 4;
+	uint64_t k_mker = k / 4;
+	uint64_t k_left = k % 4;
 	uint64_t rs_c   = rs_c0;
 	uint64_t cs_c   = cs_c0;
 	uint64_t info   = 0;
 
+	uint64_t mr     = svcntd();
+	GEMM_UKR_SETUP_CT( z, mr, 10, false );
+
 	__asm__ volatile (
 	// " ldr x0, %[a] \n\t"
 	// " ldr x1, %[b] \n\t"
@@ -309,5 +316,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
 	 "z24","z25","z26","z27",
 	 "z28","z29","z30","z31"
 	);
+
+	GEMM_UKR_FLUSH_CT( z );
 }


@@ -42,9 +42,13 @@
 // 2vx7 microkernels.
 #include "armsve_asm_2vx7cmplx.h"
 
+#include "arm_sve.h"
+
 void bli_zgemm_armsve_asm_2vx7_unindexed
     (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        dcomplex*  restrict alpha,
        dcomplex*  restrict a,
        dcomplex*  restrict b,
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx7_unindexed
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint64_t k_mker = k0 / 4;
-	uint64_t k_left = k0 % 4;
+	uint64_t k_mker = k / 4;
+	uint64_t k_left = k % 4;
 	uint64_t rs_c   = rs_c0;
 	uint64_t cs_c   = cs_c0;
 	uint64_t info   = 0;
 
+	uint64_t mr     = svcntd();
+	GEMM_UKR_SETUP_CT( z, mr, 7, false );
+
 	__asm__ volatile (
 	// " ldr x0, %[a] \n\t"
 	// " ldr x1, %[b] \n\t"
@@ -261,6 +268,8 @@ GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27
 	 "z24","z25","z26","z27",
 	 "z28","z29","z30","z31"
 	);
+
+	GEMM_UKR_FLUSH_CT( z );
 }


@@ -42,9 +42,13 @@
 // 2vx8 microkernels.
 #include "armsve_asm_2vx8cmplx.h"
 
+#include "arm_sve.h"
+
 void bli_zgemm_armsve_asm_2vx8_unindexed
     (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        dcomplex*  restrict alpha,
        dcomplex*  restrict a,
        dcomplex*  restrict b,
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx8_unindexed
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint64_t k_mker = k0 / 6;
-	uint64_t k_left = k0 % 6;
+	uint64_t k_mker = k / 6;
+	uint64_t k_left = k % 6;
 	uint64_t rs_c   = rs_c0;
 	uint64_t cs_c   = cs_c0;
 	uint64_t info   = 0;
 
+	uint64_t mr     = svcntd();
+	GEMM_UKR_SETUP_CT( z, mr, 8, false );
+
 	__asm__ volatile (
 	// " ldr x0, %[a] \n\t"
 	// " ldr x1, %[b] \n\t"
@@ -286,5 +293,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16)
 	 "z24","z25","z26","z27",
 	 "z28","z29","z30","z31"
 	);
+
+	GEMM_UKR_FLUSH_CT( z );
 }


@@ -48,23 +48,23 @@ void bli_sgemm_armv7a_ker_4x4
 
 void bli_sgemm_armv7a_asm_4x4
      (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        float*     restrict alpha,
        float*     restrict a,
        float*     restrict b,
        float*     restrict beta,
-       float*     restrict c, inc_t rs_c0, inc_t cs_c0,
+       float*     restrict c, inc_t rs_c, inc_t cs_c,
        auxinfo_t* restrict data,
        cntx_t*    restrict cntx
      )
 {
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint32_t k    = k0;
-	uint32_t rs_c = rs_c0;
-	uint32_t cs_c = cs_c0;
+	GEMM_UKR_SETUP_CT_ANY( s, 4, 4, false );
 
 	bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
+
+	GEMM_UKR_FLUSH_CT( s );
 }
@@ -83,23 +83,23 @@ void bli_dgemm_armv7a_ker_4x4
 
 void bli_dgemm_armv7a_asm_4x4
     (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        double*    restrict alpha,
        double*    restrict a,
        double*    restrict b,
        double*    restrict beta,
-       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
+       double*    restrict c, inc_t rs_c, inc_t cs_c,
        auxinfo_t* restrict data,
        cntx_t*    restrict cntx
     )
 {
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint32_t k    = k0;
-	uint32_t rs_c = rs_c0;
-	uint32_t cs_c = cs_c0;
+	GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false );
 
 	bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
+
+	GEMM_UKR_FLUSH_CT( d );
 }
@@ -118,23 +118,23 @@ void bli_cgemm_armv7a_ker_2x2
 
 void bli_cgemm_armv7a_asm_2x2
    (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        scomplex*  restrict alpha,
        scomplex*  restrict a,
        scomplex*  restrict b,
        scomplex*  restrict beta,
-       scomplex*  restrict c, inc_t rs_c0, inc_t cs_c0,
+       scomplex*  restrict c, inc_t rs_c, inc_t cs_c,
        auxinfo_t* restrict data,
        cntx_t*    restrict cntx
     )
 {
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint32_t k    = k0;
-	uint32_t rs_c = rs_c0;
-	uint32_t cs_c = cs_c0;
+	GEMM_UKR_SETUP_CT_ANY( c, 2, 2, false );
 
 	bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
+
+	GEMM_UKR_FLUSH_CT( c );
 }
@@ -153,22 +153,22 @@ void bli_zgemm_armv7a_ker_2x2
 
 void bli_zgemm_armv7a_asm_2x2
    (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        dcomplex*  restrict alpha,
        dcomplex*  restrict a,
        dcomplex*  restrict b,
        dcomplex*  restrict beta,
-       dcomplex*  restrict c, inc_t rs_c0, inc_t cs_c0,
+       dcomplex*  restrict c, inc_t rs_c, inc_t cs_c,
        auxinfo_t* restrict data,
        cntx_t*    restrict cntx
    )
 {
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint32_t k    = k0;
-	uint32_t rs_c = rs_c0;
-	uint32_t cs_c = cs_c0;
+	GEMM_UKR_SETUP_CT_ANY( z, 2, 2, false );
 
 	bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
+
+	GEMM_UKR_FLUSH_CT( z );
 }


@@ -37,7 +37,9 @@
 void bli_sgemm_armv7a_int_4x4
      (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        float*     restrict alpha,
        float*     restrict a,
        float*     restrict b,
@@ -49,12 +51,14 @@ void bli_sgemm_armv7a_int_4x4
 {
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
-	uint32_t k_iter = k0 / 4;
-	uint32_t k_left = k0 % 4;
+	uint32_t k_iter = k / 4;
+	uint32_t k_left = k % 4;
 	uint32_t rs_c   = rs_c0;
 	uint32_t cs_c   = cs_c0;
 	uint32_t i;
 
+	GEMM_UKR_SETUP_CT( s, 4, 4, false );
+
 	void* a_next = bli_auxinfo_next_a( data );
 	void* b_next = bli_auxinfo_next_b( data );
@@ -82,47 +86,17 @@ void bli_sgemm_armv7a_int_4x4
 	if ( *beta != 0.0F )
 	{
-		if ( rs_c == 1 )
-		{
-			// Load column 0
-			cv0 = vld1q_f32( c + 0*rs_c + 0*cs_c );
-			// Load column 1
-			cv1 = vld1q_f32( c + 0*rs_c + 1*cs_c );
-			// Load column 2
-			cv2 = vld1q_f32( c + 0*rs_c + 2*cs_c );
-			// Load column 3
-			cv3 = vld1q_f32( c + 0*rs_c + 3*cs_c );
-		}
-		else
-		{
-			// Load column 0
-			cv0 = vld1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0);
-			cv0 = vld1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1);
-			cv0 = vld1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2);
-			cv0 = vld1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3);
-			// Load column 1
-			cv1 = vld1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0);
-			cv1 = vld1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1);
-			cv1 = vld1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2);
-			cv1 = vld1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3);
-			// Load column 2
-			cv2 = vld1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0);
-			cv2 = vld1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1);
-			cv2 = vld1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2);
-			cv2 = vld1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3);
-			// Load column 3
-			cv3 = vld1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0);
-			cv3 = vld1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1);
-			cv3 = vld1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2);
-			cv3 = vld1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3);
-		}
+		// Load column 0
+		cv0 = vld1q_f32( c + 0*cs_c );
+		// Load column 1
+		cv1 = vld1q_f32( c + 1*cs_c );
+		// Load column 2
+		cv2 = vld1q_f32( c + 2*cs_c );
+		// Load column 3
+		cv3 = vld1q_f32( c + 3*cs_c );
 	}
 	else
 	{
@@ -255,47 +229,22 @@ void bli_sgemm_armv7a_int_4x4
 		cv3 = vmlaq_f32( cv3, abv3, alphav );
 	}
 
-	if ( rs_c == 1 )
-	{
-		// Store column 0
-		vst1q_f32( c + 0*rs_c + 0*cs_c, cv0 );
-		// Store column 1
-		vst1q_f32( c + 0*rs_c + 1*cs_c, cv1 );
-		// Store column 2
-		vst1q_f32( c + 0*rs_c + 2*cs_c, cv2 );
-		// Store column 3
-		vst1q_f32( c + 0*rs_c + 3*cs_c, cv3 );
-	}
-	else
-	{
-		// Store column 0
-		vst1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0);
-		vst1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1);
-		vst1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2);
-		vst1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3);
-		// Store column 1
-		vst1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0);
-		vst1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1);
-		vst1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2);
-		vst1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3);
-		// Store column 2
-		vst1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0);
-		vst1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1);
-		vst1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2);
-		vst1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3);
-		// Store column 3
-		vst1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0);
-		vst1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1);
-		vst1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2);
-		vst1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3);
-	}
+	// Store column 0
+	vst1q_f32( c + 0*cs_c, cv0 );
+	// Store column 1
+	vst1q_f32( c + 1*cs_c, cv1 );
+	// Store column 2
+	vst1q_f32( c + 2*cs_c, cv2 );
+	// Store column 3
+	vst1q_f32( c + 3*cs_c, cv3 );
+
+	GEMM_UKR_FLUSH_CT( s );
 }
 
 void bli_dgemm_armv7a_int_4x4
     (
+       dim_t m,
+       dim_t n,
        dim_t k,
        double* restrict alpha,
        double* restrict a,
@@ -314,6 +263,8 @@ void bli_dgemm_armv7a_int_4x4
 	uint32_t cs_c   = cs_c0;
 	uint32_t i;
 
+	GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false );
+
 	//void* a_next = bli_auxinfo_next_a( data );
 	//void* b_next = bli_auxinfo_next_b( data );
@@ -568,5 +519,7 @@ void bli_dgemm_armv7a_int_4x4
 	*c23 += ab23 * *alpha;
 	*c33 += ab33 * *alpha;
 	}
+
+	GEMM_UKR_FLUSH_CT( d );
 }

File diff suppressed because it is too large.


@@ -56,6 +56,8 @@
 void bli_dgemm_bgq_int_8x8
      (
+       dim_t            m,
+       dim_t            n,
        dim_t            k,
        double* restrict alpha,
        double* restrict a,
@@ -66,6 +68,8 @@ void bli_dgemm_bgq_int_8x8
        cntx_t* restrict cntx
      )
 {
+	GEMM_UKR_SETUP_CT_ANY( d, 8, 8, false );
+
 	//Registers for storing C.
 	//4 4x4 subblocks of C, c00, c01, c10, c11
 	//4 registers per subblock: a, b, c, d
@@ -201,6 +205,8 @@ void bli_dgemm_bgq_int_8x8
 	UPDATE( AB, c, 0 );
 	AB = vec_perm( c11d, c11d, pattern );
 	UPDATE( AB, c, 4 );
+
+	GEMM_UKR_FLUSH_CT( d );
 }
 
 void printvec(vector4double v)
@@ -214,6 +220,8 @@ void printvec(vector4double v)
 void bli_zgemm_bgq_int_4x4
      (
+       dim_t             m,
+       dim_t             n,
        dim_t             k,
        dcomplex* restrict alpha,
        dcomplex* restrict a,
@@ -224,6 +232,8 @@ void bli_zgemm_bgq_int_4x4
        cntx_t* restrict cntx
      )
 {
+	GEMM_UKR_SETUP_CT_ANY( z, 4, 4, false );
+
 	double* a_d = ( double* )a;
 	double* b_d = ( double* )b;
 	double* c_d = ( double* )c;
@@ -368,4 +378,6 @@ void bli_zgemm_bgq_int_4x4
 	c_d += 2*cs_c;
 	ZUPDATE( c03a, c03b, c_d, 0 );
 	ZUPDATE( c13a, c13b, c_d, 4 );
+
+	GEMM_UKR_FLUSH_CT( z );
 }

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -256,6 +256,8 @@ extern int offsets[16];
 //#define LOOPMON
 void bli_dgemm_knc_asm_30x8
     (
+       dim_t            m,
+       dim_t            n,
        dim_t            k,
        double* restrict alpha,
        double* restrict a,
@@ -273,6 +275,8 @@ void bli_dgemm_knc_asm_30x8
 
 	uint64_t k64 = k;
 
+	GEMM_UKR_SETUP_CT( d, 30, 8, true );
+
 #ifdef MONITORS
 	int toph, topl, both, botl, midl, midh, mid2l, mid2h;
 #endif
@@ -403,15 +407,9 @@ void bli_dgemm_knc_asm_30x8
 		mov r9, c //load address of c for update
 		mov r12, alpha //load address of alpha
 
-		// Check if C is row stride. If not, jump to the slow scattered update
-		mov r14, cs_c
-		dec r14
-		jne SCATTEREDUPDATE
-
 		mov r14, beta
 		vbroadcastsd zmm31, 0[r14]
 
 		vmulpd zmm0, zmm0, 0[r12]{1to8}
 		vmulpd zmm1, zmm1, 0[r12]{1to8}
 		vmulpd zmm2, zmm2, 0[r12]{1to8}
@@ -517,47 +515,6 @@ void bli_dgemm_knc_asm_30x8
 		vmovapd [r9+0],     zmm28
 		vmovapd [r9+r11+0], zmm29
 
-		jmp END
-
-		SCATTEREDUPDATE:
-		mov r10, offsetPtr
-		vmovapd zmm31, 0[r10]
-		vpbroadcastd zmm30, cs_c
-		mov r13, beta
-		vpmulld zmm30, zmm31, zmm30
-		mov ebx, 255
-		UPDATE_C_ROW_SCATTERED(zmm0,  0,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm1,  1,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm2,  2,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm3,  3,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm4,  4,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm5,  5,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm6,  6,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm7,  7,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm8,  8,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm9,  9,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm10, 10, r9)
-		UPDATE_C_ROW_SCATTERED(zmm11, 11, r9)
-		UPDATE_C_ROW_SCATTERED(zmm12, 12, r9)
-		UPDATE_C_ROW_SCATTERED(zmm13, 13, r9)
-		UPDATE_C_ROW_SCATTERED(zmm14, 14, r9)
-		UPDATE_C_ROW_SCATTERED(zmm15, 15, r9)
-		UPDATE_C_ROW_SCATTERED(zmm16, 16, r9)
-		UPDATE_C_ROW_SCATTERED(zmm17, 17, r9)
-		UPDATE_C_ROW_SCATTERED(zmm18, 18, r9)
-		UPDATE_C_ROW_SCATTERED(zmm19, 19, r9)
-		UPDATE_C_ROW_SCATTERED(zmm20, 20, r9)
-		UPDATE_C_ROW_SCATTERED(zmm21, 21, r9)
-		UPDATE_C_ROW_SCATTERED(zmm22, 22, r9)
-		UPDATE_C_ROW_SCATTERED(zmm23, 23, r9)
-		UPDATE_C_ROW_SCATTERED(zmm24, 24, r9)
-		UPDATE_C_ROW_SCATTERED(zmm25, 25, r9)
-		UPDATE_C_ROW_SCATTERED(zmm26, 26, r9)
-		UPDATE_C_ROW_SCATTERED(zmm27, 27, r9)
-		UPDATE_C_ROW_SCATTERED(zmm28, 28, r9)
-		UPDATE_C_ROW_SCATTERED(zmm29, 29, r9)
-
 		END:
 #ifdef MONITORS
 		rdtsc
@@ -566,6 +523,8 @@ void bli_dgemm_knc_asm_30x8
 #endif
 	}
 
+	GEMM_UKR_FLUSH_CT( d );
+
 #ifdef LOOPMON
 	printf("looptime = \t%d\n", bloopl - tloopl);
 #endif


@@ -256,6 +256,8 @@ int offsets[16] __attribute__((aligned(0x1000))) = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9
 //#define LOOPMON
 void bli_sgemm_knc_asm_30x16
     (
+       dim_t            m,
+       dim_t            n,
        dim_t            k,
        float* restrict alpha,
        float* restrict a,
@@ -273,6 +275,8 @@ void bli_sgemm_knc_asm_30x16
 
 	uint64_t k64 = k;
 
+	GEMM_UKR_SETUP_CT( s, 30, 16, true );
+
 #ifdef MONITORS
 	int toph, topl, both, botl, midl, midh, mid2l, mid2h;
 #endif
@@ -403,11 +407,6 @@ void bli_sgemm_knc_asm_30x16
 		mov r9, c //load address of c for update
 		mov r12, alpha //load address of alpha
 
-		// Check if C is row stride. If not, jump to the slow scattered update
-		mov r14, cs_c
-		dec r14
-		jne SCATTEREDUPDATE
-
 		mov r14, beta
 		vbroadcastss zmm31, 0[r14]
@@ -517,48 +516,6 @@ void bli_sgemm_knc_asm_30x16
 		vmovaps [r9+0],     zmm28
 		vmovaps [r9+r11+0], zmm29
 
-		jmp END
-
-		SCATTEREDUPDATE:
-		mov r10, offsetPtr
-		vmovaps zmm31, 0[r10]
-		vpbroadcastd zmm30, cs_c
-		mov r13, beta
-		vpmulld zmm30, zmm31, zmm30
-		mov ebx, 0xFFFF
-		UPDATE_C_ROW_SCATTERED(zmm0,  0,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm1,  1,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm2,  2,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm3,  3,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm4,  4,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm5,  5,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm6,  6,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm7,  7,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm8,  8,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm9,  9,  r9)
-		UPDATE_C_ROW_SCATTERED(zmm10, 10, r9)
-		UPDATE_C_ROW_SCATTERED(zmm11, 11, r9)
-		UPDATE_C_ROW_SCATTERED(zmm12, 12, r9)
-		UPDATE_C_ROW_SCATTERED(zmm13, 13, r9)
-		UPDATE_C_ROW_SCATTERED(zmm14, 14, r9)
-		UPDATE_C_ROW_SCATTERED(zmm15, 15, r9)
-		UPDATE_C_ROW_SCATTERED(zmm16, 16, r9)
-		UPDATE_C_ROW_SCATTERED(zmm17, 17, r9)
-		UPDATE_C_ROW_SCATTERED(zmm18, 18, r9)
-		UPDATE_C_ROW_SCATTERED(zmm19, 19, r9)
-		UPDATE_C_ROW_SCATTERED(zmm20, 20, r9)
-		UPDATE_C_ROW_SCATTERED(zmm21, 21, r9)
-		UPDATE_C_ROW_SCATTERED(zmm22, 22, r9)
-		UPDATE_C_ROW_SCATTERED(zmm23, 23, r9)
-		UPDATE_C_ROW_SCATTERED(zmm24, 24, r9)
-		UPDATE_C_ROW_SCATTERED(zmm25, 25, r9)
-		UPDATE_C_ROW_SCATTERED(zmm26, 26, r9)
-		UPDATE_C_ROW_SCATTERED(zmm27, 27, r9)
-		UPDATE_C_ROW_SCATTERED(zmm28, 28, r9)
-		UPDATE_C_ROW_SCATTERED(zmm29, 29, r9)
-
 		END:
 #ifdef MONITORS
 		rdtsc
@@ -567,6 +524,8 @@ void bli_sgemm_knc_asm_30x16
 #endif
 	}
 
+	GEMM_UKR_FLUSH_CT( s );
+
 #ifdef LOOPMON
 	printf("looptime = \t%d\n", bloopl - tloopl);
 #endif


@@ -185,6 +185,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) =
//#define LOOPMON //#define LOOPMON
void bli_dgemm_knl_asm_24x8 void bli_dgemm_knl_asm_24x8
( (
dim_t m,
dim_t n,
dim_t k_, dim_t k_,
double* restrict alpha, double* restrict alpha,
double* restrict a, double* restrict a,
@@ -201,10 +203,12 @@ void bli_dgemm_knl_asm_24x8
const double * a_next = bli_auxinfo_next_a( data ); const double * a_next = bli_auxinfo_next_a( data );
const double * b_next = bli_auxinfo_next_b( data ); const double * b_next = bli_auxinfo_next_b( data );
const int32_t * offsetPtr = &offsets[0]; int32_t * offsetPtr = &offsets[0];
const int64_t k = k_; int64_t k = k_;
const int64_t rs_c = rs_c_; int64_t rs_c = rs_c_;
const int64_t cs_c = cs_c_; int64_t cs_c = cs_c_;
GEMM_UKR_SETUP_CT( d, 24, 8, true );
#ifdef MONITORS #ifdef MONITORS
int toph, topl, both, botl, midl, midh, mid2l, mid2h; int toph, topl, both, botl, midl, midh, mid2l, mid2h;
@@ -565,10 +569,7 @@ void bli_dgemm_knl_asm_24x8
// Check if C is row stride. If not, jump to the slow scattered update // C is guaranteed to have row stride here (other layouts are redirected through the microtile temporary), so only the fast path remains
MOV(RAX, VAR(rs_c)) MOV(RAX, VAR(rs_c))
LEA(RAX, MEM(,RAX,8)) LEA(RAX, MEM(,RAX,8))
MOV(RBX, VAR(cs_c))
LEA(RDI, MEM(RAX,RAX,2)) LEA(RDI, MEM(RAX,RAX,2))
CMP(RBX, IMM(1))
JNE(SCATTEREDUPDATE)
VMOVQ(RDX, XMM(1)) VMOVQ(RDX, XMM(1))
SAL(RDX) //shift out sign bit SAL(RDX) //shift out sign bit
@@ -592,74 +593,6 @@ void bli_dgemm_knl_asm_24x8
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
JMP(END)
LABEL(SCATTEREDUPDATE)
MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(2), MEM(RDI))
/* Note that this ignores the upper 32 bits in cs_c */
VPBROADCASTD(ZMM(3), EBX)
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
VMOVQ(RDX, XMM(1))
SAL(RDX) //shift out sign bit
JZ(SCATTERBZ)
UPDATE_C_ROW_SCATTERED( 8)
UPDATE_C_ROW_SCATTERED( 9)
UPDATE_C_ROW_SCATTERED(10)
UPDATE_C_ROW_SCATTERED(11)
UPDATE_C_ROW_SCATTERED(12)
UPDATE_C_ROW_SCATTERED(13)
UPDATE_C_ROW_SCATTERED(14)
UPDATE_C_ROW_SCATTERED(15)
UPDATE_C_ROW_SCATTERED(16)
UPDATE_C_ROW_SCATTERED(17)
UPDATE_C_ROW_SCATTERED(18)
UPDATE_C_ROW_SCATTERED(19)
UPDATE_C_ROW_SCATTERED(20)
UPDATE_C_ROW_SCATTERED(21)
UPDATE_C_ROW_SCATTERED(22)
UPDATE_C_ROW_SCATTERED(23)
UPDATE_C_ROW_SCATTERED(24)
UPDATE_C_ROW_SCATTERED(25)
UPDATE_C_ROW_SCATTERED(26)
UPDATE_C_ROW_SCATTERED(27)
UPDATE_C_ROW_SCATTERED(28)
UPDATE_C_ROW_SCATTERED(29)
UPDATE_C_ROW_SCATTERED(30)
UPDATE_C_ROW_SCATTERED(31)
JMP(END)
LABEL(SCATTERBZ)
UPDATE_C_BZ_ROW_SCATTERED( 8)
UPDATE_C_BZ_ROW_SCATTERED( 9)
UPDATE_C_BZ_ROW_SCATTERED(10)
UPDATE_C_BZ_ROW_SCATTERED(11)
UPDATE_C_BZ_ROW_SCATTERED(12)
UPDATE_C_BZ_ROW_SCATTERED(13)
UPDATE_C_BZ_ROW_SCATTERED(14)
UPDATE_C_BZ_ROW_SCATTERED(15)
UPDATE_C_BZ_ROW_SCATTERED(16)
UPDATE_C_BZ_ROW_SCATTERED(17)
UPDATE_C_BZ_ROW_SCATTERED(18)
UPDATE_C_BZ_ROW_SCATTERED(19)
UPDATE_C_BZ_ROW_SCATTERED(20)
UPDATE_C_BZ_ROW_SCATTERED(21)
UPDATE_C_BZ_ROW_SCATTERED(22)
UPDATE_C_BZ_ROW_SCATTERED(23)
UPDATE_C_BZ_ROW_SCATTERED(24)
UPDATE_C_BZ_ROW_SCATTERED(25)
UPDATE_C_BZ_ROW_SCATTERED(26)
UPDATE_C_BZ_ROW_SCATTERED(27)
UPDATE_C_BZ_ROW_SCATTERED(28)
UPDATE_C_BZ_ROW_SCATTERED(29)
UPDATE_C_BZ_ROW_SCATTERED(30)
UPDATE_C_BZ_ROW_SCATTERED(31)
LABEL(END) LABEL(END)
#ifdef MONITORS #ifdef MONITORS
@@ -701,6 +634,8 @@ void bli_dgemm_knl_asm_24x8
"zmm30", "zmm31", "memory" "zmm30", "zmm31", "memory"
) )
GEMM_UKR_FLUSH_CT( d );
#ifdef LOOPMON #ifdef LOOPMON
printf("looptime = \t%d\n", bloopl - tloopl); printf("looptime = \t%d\n", bloopl - tloopl);
#endif #endif
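The SCATTEREDUPDATE/SCATTERBZ blocks removed above were the general-stride write-out: an index vector built from offsetPtr and cs_c steered each element of a register row into C. In scalar form, one UPDATE_C_ROW_SCATTERED invocation amounted to something like the following (hypothetical helper; the real code used vector scatter instructions):

    typedef long dim_t;
    typedef long inc_t;

    /* Store one alpha-scaled row of the microtile into C with an
       arbitrary column stride; the _BZ variant omits the beta term. */
    static void update_c_row_scattered( dim_t nr, double alpha, double beta,
                                        const double* acc_row,
                                        double* c_row, inc_t cs_c )
    {
        for ( dim_t j = 0; j < nr; j++ )
            c_row[ j*cs_c ] = alpha * acc_row[ j ]
                            + beta * c_row[ j*cs_c ];
    }

Since the setup macro now routes any C without unit column stride through a contiguous temporary, the assembly only ever sees row-stored tiles and this slow path can go.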


@@ -182,6 +182,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) =
//#define LOOPMON //#define LOOPMON
void bli_sgemm_knl_asm_24x16 void bli_sgemm_knl_asm_24x16
( (
dim_t m,
dim_t n,
dim_t k_, dim_t k_,
float* restrict alpha, float* restrict alpha,
float* restrict a, float* restrict a,
@@ -198,10 +200,12 @@ void bli_sgemm_knl_asm_24x16
const double * a_next = bli_auxinfo_next_a( data ); const double * a_next = bli_auxinfo_next_a( data );
const double * b_next = bli_auxinfo_next_b( data ); const double * b_next = bli_auxinfo_next_b( data );
const int32_t * offsetPtr = &offsets[0]; int32_t * offsetPtr = &offsets[0];
const int64_t k = k_; int64_t k = k_;
const int64_t rs_c = rs_c_; int64_t rs_c = rs_c_;
const int64_t cs_c = cs_c_; int64_t cs_c = cs_c_;
GEMM_UKR_SETUP_CT( s, 24, 16, true );
#ifdef MONITORS #ifdef MONITORS
int toph, topl, both, botl, midl, midh, mid2l, mid2h; int toph, topl, both, botl, midl, midh, mid2l, mid2h;
@@ -562,10 +566,7 @@ void bli_sgemm_knl_asm_24x16
// Check if C is row stride. If not, jump to the slow scattered update // C is guaranteed to have row stride here (other layouts are redirected through the microtile temporary), so only the fast path remains
MOV(RAX, VAR(rs_c)) MOV(RAX, VAR(rs_c))
LEA(RAX, MEM(,RAX,4)) LEA(RAX, MEM(,RAX,4))
MOV(RBX, VAR(cs_c))
LEA(RDI, MEM(RAX,RAX,2)) LEA(RDI, MEM(RAX,RAX,2))
CMP(RBX, IMM(1))
JNE(SCATTEREDUPDATE)
VMOVD(EDX, XMM(1)) VMOVD(EDX, XMM(1))
SAL(EDX) //shift out sign bit SAL(EDX) //shift out sign bit
@@ -589,74 +590,6 @@ void bli_sgemm_knl_asm_24x16
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
JMP(END)
LABEL(SCATTEREDUPDATE)
MOV(RDI, VAR(offsetPtr))
VMOVAPS(ZMM(2), MEM(RDI))
/* Note that this ignores the upper 32 bits in cs_c */
VPBROADCASTD(ZMM(3), EBX)
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
VMOVD(EDX, XMM(1))
SAL(EDX) //shift out sign bit
JZ(SCATTERBZ)
UPDATE_C_ROW_SCATTERED( 8)
UPDATE_C_ROW_SCATTERED( 9)
UPDATE_C_ROW_SCATTERED(10)
UPDATE_C_ROW_SCATTERED(11)
UPDATE_C_ROW_SCATTERED(12)
UPDATE_C_ROW_SCATTERED(13)
UPDATE_C_ROW_SCATTERED(14)
UPDATE_C_ROW_SCATTERED(15)
UPDATE_C_ROW_SCATTERED(16)
UPDATE_C_ROW_SCATTERED(17)
UPDATE_C_ROW_SCATTERED(18)
UPDATE_C_ROW_SCATTERED(19)
UPDATE_C_ROW_SCATTERED(20)
UPDATE_C_ROW_SCATTERED(21)
UPDATE_C_ROW_SCATTERED(22)
UPDATE_C_ROW_SCATTERED(23)
UPDATE_C_ROW_SCATTERED(24)
UPDATE_C_ROW_SCATTERED(25)
UPDATE_C_ROW_SCATTERED(26)
UPDATE_C_ROW_SCATTERED(27)
UPDATE_C_ROW_SCATTERED(28)
UPDATE_C_ROW_SCATTERED(29)
UPDATE_C_ROW_SCATTERED(30)
UPDATE_C_ROW_SCATTERED(31)
JMP(END)
LABEL(SCATTERBZ)
UPDATE_C_BZ_ROW_SCATTERED( 8)
UPDATE_C_BZ_ROW_SCATTERED( 9)
UPDATE_C_BZ_ROW_SCATTERED(10)
UPDATE_C_BZ_ROW_SCATTERED(11)
UPDATE_C_BZ_ROW_SCATTERED(12)
UPDATE_C_BZ_ROW_SCATTERED(13)
UPDATE_C_BZ_ROW_SCATTERED(14)
UPDATE_C_BZ_ROW_SCATTERED(15)
UPDATE_C_BZ_ROW_SCATTERED(16)
UPDATE_C_BZ_ROW_SCATTERED(17)
UPDATE_C_BZ_ROW_SCATTERED(18)
UPDATE_C_BZ_ROW_SCATTERED(19)
UPDATE_C_BZ_ROW_SCATTERED(20)
UPDATE_C_BZ_ROW_SCATTERED(21)
UPDATE_C_BZ_ROW_SCATTERED(22)
UPDATE_C_BZ_ROW_SCATTERED(23)
UPDATE_C_BZ_ROW_SCATTERED(24)
UPDATE_C_BZ_ROW_SCATTERED(25)
UPDATE_C_BZ_ROW_SCATTERED(26)
UPDATE_C_BZ_ROW_SCATTERED(27)
UPDATE_C_BZ_ROW_SCATTERED(28)
UPDATE_C_BZ_ROW_SCATTERED(29)
UPDATE_C_BZ_ROW_SCATTERED(30)
UPDATE_C_BZ_ROW_SCATTERED(31)
LABEL(END) LABEL(END)
#ifdef MONITORS #ifdef MONITORS
@@ -698,6 +631,8 @@ void bli_sgemm_knl_asm_24x16
"zmm30", "zmm31", "memory" "zmm30", "zmm31", "memory"
) )
GEMM_UKR_FLUSH_CT( s );
#ifdef LOOPMON #ifdef LOOPMON
printf("looptime = \t%d\n", bloopl - tloopl); printf("looptime = \t%d\n", bloopl - tloopl);
#endif #endif

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -62,12 +62,14 @@
void bli_dgemm_power10_mma_8x8 void bli_dgemm_power10_mma_8x8
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
double* restrict alpha, double* restrict alpha,
double* restrict a, double* restrict a,
double* restrict b, double* restrict b,
double* restrict beta, double* restrict beta,
double* restrict c, inc_t rs_c0, inc_t cs_c0, double* restrict c, inc_t rs_c0, inc_t cs_c,
auxinfo_t* restrict data, auxinfo_t* restrict data,
cntx_t* restrict cntx cntx_t* restrict cntx
) )
@@ -76,11 +78,13 @@ void bli_dgemm_power10_mma_8x8
// Typecast local copies of integers in case dim_t and inc_t are a // Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions. // different size than is expected by load instructions.
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out) // (1 is subtracted from k because 1 iteration of the k loop is pulled out)
uint64_t k_iter = (k0-1) / 4; uint64_t k_iter = (k-1) / 4;
uint64_t k_left = (k0-1) % 4; uint64_t k_left = (k-1) % 4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;
GEMM_UKR_SETUP_CT( d, 8, 8, true );
double* restrict A0 = a; double* restrict A0 = a;
double* restrict B0 = b; double* restrict B0 = b;
double* restrict C0 = c; double* restrict C0 = c;
@@ -189,4 +193,5 @@ void bli_dgemm_power10_mma_8x8
SAVE_ACC_bz(dv4sf_t, &acc7, rs_c, 6+4*rs_c); SAVE_ACC_bz(dv4sf_t, &acc7, rs_c, 6+4*rs_c);
} }
GEMM_UKR_FLUSH_CT( d );
} }
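A quick check of the k-loop accounting shared by these MMA kernels: one iteration is peeled off up front and the remainder is unrolled by four, so for k = 10 we get k_iter = (10-1)/4 = 2 and k_left = (10-1)%4 = 1, and 1 + 4*2 + 1 = 10. As a self-contained assertion:

    #include <assert.h>
    #include <stdint.h>

    static void check_k_split( uint64_t k )   /* requires k >= 1 */
    {
        uint64_t k_iter = (k - 1) / 4;   /* groups of four unrolled iterations */
        uint64_t k_left = (k - 1) % 4;   /* leftover single iterations         */
        assert( 1 + 4*k_iter + k_left == k );
    }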


@@ -55,7 +55,9 @@
void bli_i16gemm_power10_mma_8x16 void bli_i16gemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
int32_t* restrict alpha, int32_t* restrict alpha,
short* restrict a, short* restrict a,
short* restrict b, short* restrict b,
@@ -66,8 +68,8 @@ void bli_i16gemm_power10_mma_8x16
) )
{ {
uint64_t k_iter = (k0-1) / 4; uint64_t k_iter = (k-1) / 4;
uint64_t k_left = (k0-1) % 4; uint64_t k_left = (k-1) % 4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;


@@ -55,7 +55,9 @@
void bli_i16sgemm_power10_mma_8x16 void bli_i16sgemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
int32_t* restrict alpha, int32_t* restrict alpha,
short* restrict a, short* restrict a,
short* restrict b, short* restrict b,
@@ -66,8 +68,8 @@ void bli_i16sgemm_power10_mma_8x16
) )
{ {
uint64_t k_iter = (k0-1) / 4; uint64_t k_iter = (k-1) / 4;
uint64_t k_left = (k0-1) % 4; uint64_t k_left = (k-1) % 4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;


@@ -55,7 +55,9 @@
void bli_i4gemm_power10_mma_8x16 void bli_i4gemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
int32_t* restrict alpha, int32_t* restrict alpha,
nibbles* restrict a, nibbles* restrict a,
nibbles* restrict b, nibbles* restrict b,
@@ -66,8 +68,8 @@ void bli_i4gemm_power10_mma_8x16
) )
{ {
uint64_t k_iter = (k0-1) / 4; uint64_t k_iter = (k-1) / 4;
uint64_t k_left = (k0-1) % 4; uint64_t k_left = (k-1) % 4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;
@@ -100,19 +102,19 @@ void bli_i4gemm_power10_mma_8x16
I4_INCREMENT I4_INCREMENT
// k loop (unrolled by 4) // k loop (unrolled by 4)
for (int k = 0; k<k_iter; k++) for (int k = 0; k<k_iter; k++)
{ {
I4_AB_PRODUCT I4_AB_PRODUCT
I4_AB_PRODUCT I4_AB_PRODUCT
I4_AB_PRODUCT I4_AB_PRODUCT
I4_AB_PRODUCT I4_AB_PRODUCT
} }
// edge loop // edge loop
for (int k = 0; k<k_left; k++) for (int k = 0; k<k_left; k++)
{ {
I4_AB_PRODUCT I4_AB_PRODUCT
} }
// handle beta cases // handle beta cases
if (beta_ != 0.0) if (beta_ != 0.0)


@@ -55,7 +55,9 @@
void bli_i8gemm_power10_mma_8x16 void bli_i8gemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
int32_t* restrict alpha, int32_t* restrict alpha,
int8_t* restrict a, int8_t* restrict a,
int8_t* restrict b, int8_t* restrict b,
@@ -65,8 +67,8 @@ void bli_i8gemm_power10_mma_8x16
cntx_t* restrict cntx cntx_t* restrict cntx
) )
{ {
uint64_t k_iter = (k0-1) / 4; uint64_t k_iter = (k-1) / 4;
uint64_t k_left = (k0-1) % 4; uint64_t k_left = (k-1) % 4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;
@@ -99,19 +101,19 @@ void bli_i8gemm_power10_mma_8x16
I8_INCREMENT I8_INCREMENT
// k loop (unrolled by 4) // k loop (unrolled by 4)
for (int k = 0; k<k_iter; k++) for (int k = 0; k<k_iter; k++)
{ {
I8_AB_PRODUCT I8_AB_PRODUCT
I8_AB_PRODUCT I8_AB_PRODUCT
I8_AB_PRODUCT I8_AB_PRODUCT
I8_AB_PRODUCT I8_AB_PRODUCT
} }
// edge loop // edge loop
for (int k = 0; k<k_left; k++) for (int k = 0; k<k_left; k++)
{ {
I8_AB_PRODUCT I8_AB_PRODUCT
} }
// handle beta cases // handle beta cases
if (beta_ != 0.0) if (beta_ != 0.0)


@@ -56,7 +56,9 @@
void bli_sbgemm_power10_mma_8x16 void bli_sbgemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
float* restrict alpha, float* restrict alpha,
bfloat16* restrict a, bfloat16* restrict a,
bfloat16* restrict b, bfloat16* restrict b,
@@ -67,8 +69,8 @@ void bli_sbgemm_power10_mma_8x16
) )
{ {
uint64_t k_iter = (k0-1)/4; uint64_t k_iter = (k-1)/4;
uint64_t k_left = (k0-1)%4; uint64_t k_left = (k-1)%4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;


@@ -55,12 +55,14 @@
void bli_sgemm_power10_mma_8x16 void bli_sgemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
float* restrict alpha, float* restrict alpha,
float* restrict a, float* restrict a,
float* restrict b, float* restrict b,
float* restrict beta, float* restrict beta,
float* restrict c, inc_t rs_c0, inc_t cs_c0, float* restrict c, inc_t rs_c0, inc_t cs_c,
auxinfo_t* restrict data, auxinfo_t* restrict data,
cntx_t* restrict cntx cntx_t* restrict cntx
) )
@@ -68,11 +70,13 @@ void bli_sgemm_power10_mma_8x16
// Typecast local copies of integers in case dim_t and inc_t are a // Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions. // different size than is expected by load instructions.
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out) // (1 is subtracted from k because 1 iteration of the k loop is pulled out)
uint64_t k_iter = (k0-1) / 4; uint64_t k_iter = (k-1) / 4;
uint64_t k_left = (k0-1) % 4; uint64_t k_left = (k-1) % 4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;
GEMM_UKR_SETUP_CT( s, 8, 16, true );
fv4sf_t result[4]; fv4sf_t result[4];
fv4sf_t *rowC; fv4sf_t *rowC;
@@ -141,4 +145,6 @@ void bli_sgemm_power10_mma_8x16
SAVE_ACC_bz(fv4sf_t, &acc6, rs_c, 8+4*rs_c); SAVE_ACC_bz(fv4sf_t, &acc6, rs_c, 8+4*rs_c);
SAVE_ACC_bz(fv4sf_t, &acc7, rs_c, 12+4*rs_c); SAVE_ACC_bz(fv4sf_t, &acc7, rs_c, 12+4*rs_c);
} }
GEMM_UKR_FLUSH_CT( s );
} }


@@ -56,7 +56,9 @@
void bli_shgemm_power10_mma_8x16 void bli_shgemm_power10_mma_8x16
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
float* restrict alpha, float* restrict alpha,
float16* restrict a, float16* restrict a,
float16* restrict b, float16* restrict b,
@@ -67,8 +69,8 @@ void bli_shgemm_power10_mma_8x16
) )
{ {
uint64_t k_iter = (k0-1)/4; uint64_t k_iter = (k-1)/4;
uint64_t k_left = (k0-1)%4; uint64_t k_left = (k-1)%4;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;


@@ -50,30 +50,26 @@
*/ */
void bli_sgemm_power7_int_8x4 void bli_sgemm_power7_int_8x4
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
float* restrict alpha, float* restrict alpha,
float* restrict a, float* restrict a,
float* restrict b, float* restrict b,
float* restrict beta, float* restrict beta,
float* restrict c, inc_t rs_c0, inc_t cs_c0, float* restrict c, inc_t rs_c, inc_t cs_c,
auxinfo_t* restrict data, auxinfo_t* restrict data,
cntx_t* restrict cntx cntx_t* restrict cntx
) )
{ {
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k = k0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
#if 1 || defined(UTEST) #if 1 || defined(UTEST)
const long MR = BLIS_DEFAULT_MR_S, NR = BLIS_DEFAULT_NR_S; const long MR = BLIS_DEFAULT_MR_S, NR = BLIS_DEFAULT_NR_S;
const long LDA = MR, LDB = NR; const long LDA = MR, LDB = NR;
long i, j, kk; long i, j, kk;
float c00; float c00;
for (i=0; i < MR; i++) { for (i=0; i < m; i++) {
for (j=0; j < NR; j++) { for (j=0; j < n; j++) {
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
for (kk=0; kk < k; kk++) for (kk=0; kk < k; kk++)
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]); c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
@@ -96,24 +92,160 @@ void bli_sgemm_power7_int_8x4
*/ */
void bli_dgemm_power7_int_8x4 void bli_dgemm_power7_int_8x4
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
double* restrict alpha, double* restrict alpha,
double* restrict a, double* restrict a,
double* restrict b, double* restrict b,
double* restrict beta, double* restrict beta,
double* restrict c, inc_t rs_c0, inc_t cs_c0, double* restrict c, inc_t rs_c, inc_t cs_c,
auxinfo_t* restrict data, auxinfo_t* restrict data,
cntx_t* restrict cntx cntx_t* restrict cntx
) )
{ {
// Typecast local copies of integers in case dim_t and inc_t are a if ( cs_c == 1 )
// different size than is expected by load instructions. {
uint64_t k = k0; // Optimized code for case where C rows are contiguous (i.e. C is row-major)
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0; vector double vzero = vec_splats( 0.0 );
vector double vc00_01 = vzero;
vector double vc02_03 = vzero;
vector double vc10_11 = vzero;
vector double vc12_13 = vzero;
vector double vc20_21 = vzero;
vector double vc22_23 = vzero;
vector double vc30_31 = vzero;
vector double vc32_33 = vzero;
vector double vc40_41 = vzero;
vector double vc42_43 = vzero;
vector double vc50_51 = vzero;
vector double vc52_53 = vzero;
vector double vc60_61 = vzero;
vector double vc62_63 = vzero;
vector double vc70_71 = vzero;
vector double vc72_73 = vzero;
unsigned long long pa = (unsigned long long)a;
unsigned long long pb = (unsigned long long)b;
#if 0
unsigned long long d1 = 1*sizeof(double);
unsigned long long d2 = 2*sizeof(double);
unsigned long long d3 = 3*sizeof(double);
unsigned long long d4 = 4*sizeof(double);
unsigned long long d6 = 6*sizeof(double);
#else
// ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables
register unsigned long long d1 __asm ("r21") = 1*sizeof(double);
register unsigned long long d2 __asm ("r22") = 2*sizeof(double);
register unsigned long long d3 __asm ("r23") = 3*sizeof(double);
register unsigned long long d4 __asm ("r24") = 4*sizeof(double);
register unsigned long long d5 __asm ("r25") = 5*sizeof(double);
register unsigned long long d6 __asm ("r26") = 6*sizeof(double);
register unsigned long long d7 __asm ("r27") = 7*sizeof(double);
__asm__ volatile (";" : "=r" (d1) : "r" (d1) );
__asm__ volatile (";" : "=r" (d2) : "r" (d2) );
__asm__ volatile (";" : "=r" (d3) : "r" (d3) );
__asm__ volatile (";" : "=r" (d4) : "r" (d4) );
__asm__ volatile (";" : "=r" (d5) : "r" (d5) );
__asm__ volatile (";" : "=r" (d6) : "r" (d6) );
__asm__ volatile (";" : "=r" (d7) : "r" (d7) );
#endif
int kk;
for (kk=k; kk > 0; kk--) {
vector double va00 = vec_splats( *(double *)( pa+0 ) );
vector double va10 = vec_splats( *(double *)( pa+d1 ) );
vector double va20 = vec_splats( *(double *)( pa+d2 ) );
vector double va30 = vec_splats( *(double *)( pa+d3 ) );
vector double va40 = vec_splats( *(double *)( pa+d4 ) );
vector double va50 = vec_splats( *(double *)( pa+d5 ) );
vector double va60 = vec_splats( *(double *)( pa+d6 ) );
vector double va70 = vec_splats( *(double *)( pa+d7 ) );
pa += 8*sizeof(double);
vector double vb00_01 = *(vector double *)( pb+0 );
vector double vb02_03 = *(vector double *)( pb+d2 );
pb += 4*sizeof(double);
vc00_01 = vec_madd(va00, vb00_01, vc00_01);
vc02_03 = vec_madd(va00, vb02_03, vc02_03);
vc10_11 = vec_madd(va10, vb00_01, vc10_11);
vc12_13 = vec_madd(va10, vb02_03, vc12_13);
vc20_21 = vec_madd(va20, vb00_01, vc20_21);
vc22_23 = vec_madd(va20, vb02_03, vc22_23);
vc30_31 = vec_madd(va30, vb00_01, vc30_31);
vc32_33 = vec_madd(va30, vb02_03, vc32_33);
vc40_41 = vec_madd(va40, vb00_01, vc40_41);
vc42_43 = vec_madd(va40, vb02_03, vc42_43);
vc50_51 = vec_madd(va50, vb00_01, vc50_51);
vc52_53 = vec_madd(va50, vb02_03, vc52_53);
vc60_61 = vec_madd(va60, vb00_01, vc60_61);
vc62_63 = vec_madd(va60, vb02_03, vc62_63);
vc70_71 = vec_madd(va70, vb00_01, vc70_71);
vc72_73 = vec_madd(va70, vb02_03, vc72_73);
}
vector double valpha = vec_splats( *alpha );
vector double vbeta = (vector double) { *beta, *beta };
vector double *pc = (vector double *)c;
vc00_01 = vec_mul(valpha, vc00_01);
vc02_03 = vec_mul(valpha, vc02_03);
pc[0] = vec_madd( pc[0], vbeta, vc00_01);
pc[1] = vec_madd( pc[1], vbeta, vc02_03);
pc += rs_c/2;
vc10_11 = vec_mul(valpha, vc10_11);
vc12_13 = vec_mul(valpha, vc12_13);
pc[0] = vec_madd( pc[0], vbeta, vc10_11);
pc[1] = vec_madd( pc[1], vbeta, vc12_13);
pc += rs_c/2;
vc20_21 = vec_mul(valpha, vc20_21);
vc22_23 = vec_mul(valpha, vc22_23);
pc[0] = vec_madd( pc[0], vbeta, vc20_21);
pc[1] = vec_madd( pc[1], vbeta, vc22_23);
pc += rs_c/2;
vc30_31 = vec_mul(valpha, vc30_31);
vc32_33 = vec_mul(valpha, vc32_33);
pc[0] = vec_madd( pc[0], vbeta, vc30_31);
pc[1] = vec_madd( pc[1], vbeta, vc32_33);
pc += rs_c/2;
vc40_41 = vec_mul(valpha, vc40_41);
vc42_43 = vec_mul(valpha, vc42_43);
pc[0] = vec_madd( pc[0], vbeta, vc40_41);
pc[1] = vec_madd( pc[1], vbeta, vc42_43);
pc += rs_c/2;
vc50_51 = vec_mul(valpha, vc50_51);
vc52_53 = vec_mul(valpha, vc52_53);
pc[0] = vec_madd( pc[0], vbeta, vc50_51);
pc[1] = vec_madd( pc[1], vbeta, vc52_53);
pc += rs_c/2;
vc60_61 = vec_mul(valpha, vc60_61);
vc62_63 = vec_mul(valpha, vc62_63);
pc[0] = vec_madd( pc[0], vbeta, vc60_61);
pc[1] = vec_madd( pc[1], vbeta, vc62_63);
pc += rs_c/2;
vc70_71 = vec_mul(valpha, vc70_71);
vc72_73 = vec_mul(valpha, vc72_73);
pc[0] = vec_madd( pc[0], vbeta, vc70_71);
pc[1] = vec_madd( pc[1], vbeta, vc72_73);
pc += rs_c/2;
}
else
{
GEMM_UKR_SETUP_CT( d, 8, 4, false );
#if 1
if (rs_c == 1) {
// Optimized code for case where C columns are contiguous (column-major C) // Optimized code for case where C columns are contiguous (column-major C)
vector double vzero = vec_splats( 0.0 ); vector double vzero = vec_splats( 0.0 );
@@ -301,168 +433,8 @@ void bli_dgemm_power7_int_8x4
pc[1] = vec_madd( pc[1], vbeta, vc23_33); pc[1] = vec_madd( pc[1], vbeta, vc23_33);
pc[2] = vec_madd( pc[2], vbeta, vc43_53); pc[2] = vec_madd( pc[2], vbeta, vc43_53);
pc[3] = vec_madd( pc[3], vbeta, vc63_73); pc[3] = vec_madd( pc[3], vbeta, vc63_73);
}
else
#endif
#if 1
if ( cs_c == 1 ) {
// Optimized code for case where C rows are contiguous (i.e. C is row-major)
vector double vzero = vec_splats( 0.0 ); GEMM_UKR_FLUSH_CT( d );
vector double vc00_01 = vzero;
vector double vc02_03 = vzero;
vector double vc10_11 = vzero;
vector double vc12_13 = vzero;
vector double vc20_21 = vzero;
vector double vc22_23 = vzero;
vector double vc30_31 = vzero;
vector double vc32_33 = vzero;
vector double vc40_41 = vzero;
vector double vc42_43 = vzero;
vector double vc50_51 = vzero;
vector double vc52_53 = vzero;
vector double vc60_61 = vzero;
vector double vc62_63 = vzero;
vector double vc70_71 = vzero;
vector double vc72_73 = vzero;
unsigned long long pa = (unsigned long long)a;
unsigned long long pb = (unsigned long long)b;
#if 0
unsigned long long d1 = 1*sizeof(double);
unsigned long long d2 = 2*sizeof(double);
unsigned long long d3 = 3*sizeof(double);
unsigned long long d4 = 4*sizeof(double);
unsigned long long d6 = 6*sizeof(double);
#else
// ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables
register unsigned long long d1 __asm ("r21") = 1*sizeof(double);
register unsigned long long d2 __asm ("r22") = 2*sizeof(double);
register unsigned long long d3 __asm ("r23") = 3*sizeof(double);
register unsigned long long d4 __asm ("r24") = 4*sizeof(double);
register unsigned long long d5 __asm ("r25") = 5*sizeof(double);
register unsigned long long d6 __asm ("r26") = 6*sizeof(double);
register unsigned long long d7 __asm ("r27") = 7*sizeof(double);
__asm__ volatile (";" : "=r" (d1) : "r" (d1) );
__asm__ volatile (";" : "=r" (d2) : "r" (d2) );
__asm__ volatile (";" : "=r" (d3) : "r" (d3) );
__asm__ volatile (";" : "=r" (d4) : "r" (d4) );
__asm__ volatile (";" : "=r" (d5) : "r" (d5) );
__asm__ volatile (";" : "=r" (d6) : "r" (d6) );
__asm__ volatile (";" : "=r" (d7) : "r" (d7) );
#endif
int kk;
for (kk=k; kk > 0; kk--) {
vector double va00 = vec_splats( *(double *)( pa+0 ) );
vector double va10 = vec_splats( *(double *)( pa+d1 ) );
vector double va20 = vec_splats( *(double *)( pa+d2 ) );
vector double va30 = vec_splats( *(double *)( pa+d3 ) );
vector double va40 = vec_splats( *(double *)( pa+d4 ) );
vector double va50 = vec_splats( *(double *)( pa+d5 ) );
vector double va60 = vec_splats( *(double *)( pa+d6 ) );
vector double va70 = vec_splats( *(double *)( pa+d7 ) );
pa += 8*sizeof(double);
vector double vb00_01 = *(vector double *)( pb+0 );
vector double vb02_03 = *(vector double *)( pb+d2 );
pb += 4*sizeof(double);
vc00_01 = vec_madd(va00, vb00_01, vc00_01);
vc02_03 = vec_madd(va00, vb02_03, vc02_03);
vc10_11 = vec_madd(va10, vb00_01, vc10_11);
vc12_13 = vec_madd(va10, vb02_03, vc12_13);
vc20_21 = vec_madd(va20, vb00_01, vc20_21);
vc22_23 = vec_madd(va20, vb02_03, vc22_23);
vc30_31 = vec_madd(va30, vb00_01, vc30_31);
vc32_33 = vec_madd(va30, vb02_03, vc32_33);
vc40_41 = vec_madd(va40, vb00_01, vc40_41);
vc42_43 = vec_madd(va40, vb02_03, vc42_43);
vc50_51 = vec_madd(va50, vb00_01, vc50_51);
vc52_53 = vec_madd(va50, vb02_03, vc52_53);
vc60_61 = vec_madd(va60, vb00_01, vc60_61);
vc62_63 = vec_madd(va60, vb02_03, vc62_63);
vc70_71 = vec_madd(va70, vb00_01, vc70_71);
vc72_73 = vec_madd(va70, vb02_03, vc72_73);
}
vector double valpha = vec_splats( *alpha );
vector double vbeta = (vector double) { *beta, *beta };
vector double *pc = (vector double *)c;
vc00_01 = vec_mul(valpha, vc00_01);
vc02_03 = vec_mul(valpha, vc02_03);
pc[0] = vec_madd( pc[0], vbeta, vc00_01);
pc[1] = vec_madd( pc[1], vbeta, vc02_03);
pc += rs_c/2;
vc10_11 = vec_mul(valpha, vc10_11);
vc12_13 = vec_mul(valpha, vc12_13);
pc[0] = vec_madd( pc[0], vbeta, vc10_11);
pc[1] = vec_madd( pc[1], vbeta, vc12_13);
pc += rs_c/2;
vc20_21 = vec_mul(valpha, vc20_21);
vc22_23 = vec_mul(valpha, vc22_23);
pc[0] = vec_madd( pc[0], vbeta, vc20_21);
pc[1] = vec_madd( pc[1], vbeta, vc22_23);
pc += rs_c/2;
vc30_31 = vec_mul(valpha, vc30_31);
vc32_33 = vec_mul(valpha, vc32_33);
pc[0] = vec_madd( pc[0], vbeta, vc30_31);
pc[1] = vec_madd( pc[1], vbeta, vc32_33);
pc += rs_c/2;
vc40_41 = vec_mul(valpha, vc40_41);
vc42_43 = vec_mul(valpha, vc42_43);
pc[0] = vec_madd( pc[0], vbeta, vc40_41);
pc[1] = vec_madd( pc[1], vbeta, vc42_43);
pc += rs_c/2;
vc50_51 = vec_mul(valpha, vc50_51);
vc52_53 = vec_mul(valpha, vc52_53);
pc[0] = vec_madd( pc[0], vbeta, vc50_51);
pc[1] = vec_madd( pc[1], vbeta, vc52_53);
pc += rs_c/2;
vc60_61 = vec_mul(valpha, vc60_61);
vc62_63 = vec_mul(valpha, vc62_63);
pc[0] = vec_madd( pc[0], vbeta, vc60_61);
pc[1] = vec_madd( pc[1], vbeta, vc62_63);
pc += rs_c/2;
vc70_71 = vec_mul(valpha, vc70_71);
vc72_73 = vec_mul(valpha, vc72_73);
pc[0] = vec_madd( pc[0], vbeta, vc70_71);
pc[1] = vec_madd( pc[1], vbeta, vc72_73);
pc += rs_c/2;
}
else
#endif
{ /* General case. Just do it right. */
#if 1 || defined(UTEST)
const long MR = BLIS_DEFAULT_MR_D, NR = BLIS_DEFAULT_NR_D;
const long LDA = MR, LDB = NR;
int i, j, kk;
double c00;
for (i=0; i < MR; i++) {
for (j=0; j < NR; j++) {
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
for (kk=0; kk < k; kk++)
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00;
}
}
#else
//BLIS_DGEMM_UKERNEL_REF(k, alpha, a, b, beta, c, rs_c, cs_c, data);
#endif
} }
} }
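Each pass of the kk loop above is a rank-1 update: every element of the packed A column is splatted across a vector register and multiplied into the packed B row, two doubles per vec_madd. A scalar sketch of one such step for the 8x4 tile (illustrative only):

    /* One kk-step of the vec_splats/vec_madd sequence above; ab[] plays
       the role of the vc*_* accumulator registers. */
    static void rank1_step_8x4( const double* a, const double* b,
                                double* ab )
    {
        for ( int i = 0; i < 8; i++ )      /* va_i0 = vec_splats( a[i] )   */
            for ( int j = 0; j < 4; j++ )  /* vec_madd with b pairs        */
                ab[ i*4 + j ] += a[ i ] * b[ j ];
    }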
@@ -477,30 +449,26 @@ void bli_dgemm_power7_int_8x4
*/ */
void bli_cgemm_power7_int_8x4 void bli_cgemm_power7_int_8x4
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
scomplex* restrict alpha, scomplex* restrict alpha,
scomplex* restrict a, scomplex* restrict a,
scomplex* restrict b, scomplex* restrict b,
scomplex* restrict beta, scomplex* restrict beta,
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, scomplex* restrict c, inc_t rs_c, inc_t cs_c,
auxinfo_t* restrict data, auxinfo_t* restrict data,
cntx_t* restrict cntx cntx_t* restrict cntx
) )
{ {
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k = k0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
#if 1 || defined(UTEST) #if 1 || defined(UTEST)
const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C; const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C;
const long LDA = MR, LDB = NR; const long LDA = MR, LDB = NR;
int i, j, kk; int i, j, kk;
scomplex c00; scomplex c00;
for (i=0; i < MR; i++) { for (i=0; i < m; i++) {
for (j=0; j < NR; j++) { for (j=0; j < n; j++) {
scomplex tmpc, tmpa, tmpb, tmp; scomplex tmpc, tmpa, tmpb, tmp;
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
@@ -534,30 +502,26 @@ void bli_cgemm_power7_int_8x4
*/ */
void bli_zgemm_power7_int_8x4 void bli_zgemm_power7_int_8x4
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
dcomplex* restrict alpha, dcomplex* restrict alpha,
dcomplex* restrict a, dcomplex* restrict a,
dcomplex* restrict b, dcomplex* restrict b,
dcomplex* restrict beta, dcomplex* restrict beta,
dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, dcomplex* restrict c, inc_t rs_c, inc_t cs_c,
auxinfo_t* restrict data, auxinfo_t* restrict data,
cntx_t* restrict cntx cntx_t* restrict cntx
) )
{ {
// Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions.
uint64_t k = k0;
uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0;
#if 1 || defined(UTEST) #if 1 || defined(UTEST)
const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z; const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z;
const long LDA = MR, LDB = NR; const long LDA = MR, LDB = NR;
int i, j, kk; int i, j, kk;
dcomplex c00; dcomplex c00;
for (i=0; i < MR; i++) { for (i=0; i < m; i++) {
for (j=0; j < NR; j++) { for (j=0; j < n; j++) {
dcomplex tmpc, tmpa, tmpb, tmp; dcomplex tmpc, tmpa, tmpb, tmp;
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];


@@ -43,6 +43,8 @@
void bli_sgemm_opt_8x4 void bli_sgemm_opt_8x4
( (
dim_t m,
dim_t n,
dim_t k, dim_t k,
float* restrict alpha, float* restrict alpha,
float* restrict a, float* restrict a,
@@ -55,6 +57,8 @@ void bli_sgemm_opt_8x4
void bli_dgemm_opt_8x4 void bli_dgemm_opt_8x4
( (
dim_t m,
dim_t n,
dim_t k, dim_t k,
double* restrict alpha, double* restrict alpha,
double* restrict a, double* restrict a,
@@ -67,6 +71,8 @@ void bli_dgemm_opt_8x4
void bli_cgemm_opt_8x4 void bli_cgemm_opt_8x4
( (
dim_t m,
dim_t n,
dim_t k, dim_t k,
scomplex* restrict alpha, scomplex* restrict alpha,
scomplex* restrict a, scomplex* restrict a,
@@ -79,6 +85,8 @@ void bli_cgemm_opt_8x4
void bli_zgemm_opt_8x4 void bli_zgemm_opt_8x4
( (
dim_t m,
dim_t n,
dim_t k, dim_t k,
dcomplex* restrict alpha, dcomplex* restrict alpha,
dcomplex* restrict a, dcomplex* restrict a,
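All four prototypes above now carry the expanded argument list; written as a function-pointer type (illustrative name, assuming the BLIS types dim_t, inc_t, auxinfo_t, and cntx_t), the double-precision variant reads:

    typedef void (*dgemm_ukr_ft)
         (
           dim_t               m,
           dim_t               n,
           dim_t               k,
           double*    restrict alpha,
           double*    restrict a,
           double*    restrict b,
           double*    restrict beta,
           double*    restrict c, inc_t rs_c, inc_t cs_c,
           auxinfo_t* restrict data,
           cntx_t*    restrict cntx
         );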


@@ -37,7 +37,9 @@
void bli_dgemm_power9_asm_12x6 void bli_dgemm_power9_asm_12x6
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
double* restrict alpha, double* restrict alpha,
double* restrict a, double* restrict a,
double* restrict b, double* restrict b,
@@ -50,117 +52,91 @@ void bli_dgemm_power9_asm_12x6
// Typecast local copies of integers in case dim_t and inc_t are a // Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions. // different size than is expected by load instructions.
uint64_t k_iter = k0 / 16; uint64_t k_iter = k / 16;
uint64_t k_left = k0 % 16; uint64_t k_left = k % 16;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0; uint64_t cs_c = cs_c0;
GEMM_UKR_SETUP_CT( d, 12, 6, false );
__asm__ volatile __asm__ volatile
( (
" \n\t" " \n\t"
"ld %%r7, %2 \n\t" // load ptr of A "ld %%r7, %2 \n\t" // load ptr of A
"ld %%r8, %3 \n\t" // load ptr of B "ld %%r8, %3 \n\t" // load ptr of B
"ld %%r16, %6 \n\t" // load ptr of C "ld %%r16, %6 \n\t" // load ptr of C
" \n\t" " \n\t"
"ld %%r28, %4 \n\t" // load ptr for alpha "ld %%r28, %4 \n\t" // load ptr for alpha
"ld %%r29, %5 \n\t" // load ptr for beta "ld %%r29, %5 \n\t" // load ptr for beta
" \n\t" " \n\t"
"ld %%r11, %0 \n\t" // load k_iter "ld %%r11, %0 \n\t" // load k_iter
"ld %%r12, %1 \n\t" // load k_left "ld %%r12, %1 \n\t" // load k_left
" \n\t" " \n\t"
"ld %%r10, %8 \n\t" // load cs_c "ld %%r10, %8 \n\t" // load cs_c
"slwi %%r10, %%r10, 3 \n\t" // mul by size of elem "slwi %%r10, %%r10, 3 \n\t" // mul by size of elem
" \n\t" " \n\t"
"ld %%r9, %7 \n\t" // load rs_c "ld %%r9, %7 \n\t" // load rs_c
"slwi %%r9, %%r9, 3 \n\t" // mul by size of elem "slwi %%r9, %%r9, 3 \n\t" // mul by size of elem
" \n\t" " \n\t"
"ld %%r26, 0(%%r29) \n\t" // load val of beta "ld %%r26, 0(%%r29) \n\t" // load val of beta
" \n\t" " \n\t"
"lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha "lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha
"lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta "lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta
" \n\t" " \n\t"
"add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C "add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C
"add %%r18, %%r17, %%r10 \n\t" // col 2 of C "add %%r18, %%r17, %%r10 \n\t" // col 2 of C
"add %%r19, %%r18, %%r10 \n\t" // col 3 of C "add %%r19, %%r18, %%r10 \n\t" // col 3 of C
"add %%r20, %%r19, %%r10 \n\t" // col 4 of C "add %%r20, %%r19, %%r10 \n\t" // col 4 of C
"add %%r21, %%r20, %%r10 \n\t" // col 5 of C "add %%r21, %%r20, %%r10 \n\t" // col 5 of C
" \n\t" " \n\t"
DZERO_OUT_VREG DZERO_OUT_VREG
" \n\t" " \n\t"
DPRELOAD DPRELOAD
" \n\t" " \n\t"
"addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B "addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B
"addi %%r7, %%r7, 96 \n\t" "addi %%r7, %%r7, 96 \n\t"
" \n\t" " \n\t"
DPREFETCH DPREFETCH
" \n\t" " \n\t"
"cmpwi %%r11, 0 \n\t" // if k_iter == 0, "cmpwi %%r11, 0 \n\t" // if k_iter == 0,
"beq DCONSIDERKLEFT \n\t" // then jmp to k_left "beq DCONSIDERKLEFT \n\t" // then jmp to k_left
"mtctr %%r11 \n\t" // else, do k_iter loop "mtctr %%r11 \n\t" // else, do k_iter loop
" \n\t" " \n\t"
"DLOOPKITER: \n\t" // k_iter loop "DLOOPKITER: \n\t" // k_iter loop
" \n\t" " \n\t"
A_B_PRODUCT_16 // compute A*B A_B_PRODUCT_16 // compute A*B
" \n\t" " \n\t"
"bdnz DLOOPKITER \n\t" "bdnz DLOOPKITER \n\t"
" \n\t" " \n\t"
"DCONSIDERKLEFT: \n\t" "DCONSIDERKLEFT: \n\t"
" \n\t" " \n\t"
"cmpwi %%r12, 0 \n\t" // if k_left == 0, "cmpwi %%r12, 0 \n\t" // if k_left == 0,
"beq DPOSTACCUM \n\t" // then jmp to post accum "beq DPOSTACCUM \n\t" // then jmp to post accum
"mtctr %%r12 \n\t" // else, do k_left loop "mtctr %%r12 \n\t" // else, do k_left loop
" \n\t" " \n\t"
"DLOOPKLEFT: \n\t" // k_left loop "DLOOPKLEFT: \n\t" // k_left loop
" \n\t" " \n\t"
A_B_PRODUCT_1 A_B_PRODUCT_1
" \n\t" " \n\t"
"bdnz DLOOPKLEFT \n\t" "bdnz DLOOPKLEFT \n\t"
" \n\t" " \n\t"
"DPOSTACCUM: \n\t" "DPOSTACCUM: \n\t"
" \n\t" " \n\t"
DSCALE_ALPHA DSCALE_ALPHA
" \n\t" " \n\t"
"cmpdi %%r26, 0 \n\t" // if beta == 0, "cmpdi %%r26, 0 \n\t" // if beta == 0,
"beq DBETAZERO \n\t" // then jmp to BZ "beq DBETAZERO \n\t" // then jmp to BZ
" \n\t" " \n\t"
"cmpwi %%r9, 8 \n\t" // if rs_c == 8 DCOL_SCALE_BETA
"beq DCOLSTOREDBNZ \n\t" // then jmp to col store " \n\t"
" \n\t" "DBETAZERO: \n\t" // BZ case
"DGENSTOREDBNZ: \n\t" // BNZ gen stored case " \n\t"
" \n\t" DCOL_STORE
DGEN_LOAD_OFS_C " \n\t"
" \n\t" "DDONE: \n\t"
DGEN_SCALE_BETA " \n\t"
" \n\t" : // output operands (none)
"b DGENSTORED \n\t"
" \n\t"
"DCOLSTOREDBNZ: \n\t" // BNZ col stored case
" \n\t"
DCOL_SCALE_BETA
" \n\t"
"b DCOLSTORED \n\t"
" \n\t"
"DBETAZERO: \n\t" // BZ case
" \n\t"
"cmpwi %%r9, 8 \n\t" // if rs_c == 8,
"beq DCOLSTORED \n\t" // C is col stored
" \n\t"
"DGENSTORED: \n\t" // BZ gen stored case
" \n\t"
DGEN_LOAD_OFS_C
" \n\t"
DGEN_STORE
" \n\t"
"b DDONE \n\t"
" \n\t"
"DCOLSTORED: \n\t" // BZ col stored case
" \n\t"
DCOL_STORE
" \n\t"
"DDONE: \n\t"
" \n\t"
: // output operands (none)
: // input operands : // input operands
"m" (k_iter), // 0 "m" (k_iter), // 0
"m" (k_left), // 1 "m" (k_left), // 1
@@ -174,28 +150,30 @@ void bli_dgemm_power9_asm_12x6
"m" (b_next), // 9 "m" (b_next), // 9
"m" (a_next)*/ // 10 "m" (a_next)*/ // 10
: // register clobber list : // register clobber list
/* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */ /* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */
"r0", "r7", "r8", "r9", "r0", "r7", "r8", "r9",
"r10", "r11", "r12", "r16", "r17", "r18", "r19", "r10", "r11", "r12", "r16", "r17", "r18", "r19",
"r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29" "r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29"
#if XLC #if XLC
,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9" ,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"
, "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19" , "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19"
, "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29" , "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29"
, "f30" ,"f31" , "f30" ,"f31"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9"
, "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19" , "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"
, "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29" , "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"
, "v30", "v31" , "v30", "v31"
#else #else
, "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9" , "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9"
, "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19" , "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19"
, "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29" , "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29"
, "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39" , "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39"
, "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49" , "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49"
, "vs50", "vs51", "vs52", "vs53" , "vs50", "vs51", "vs52", "vs53"
#endif #endif
); );
GEMM_UKR_FLUSH_CT( d );
} }
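With the gen-stored branches gone, the only branch left in the power9 write-out is on beta, and it cannot be folded away: C may be uninitialized on input, and 0.0 * NaN is NaN, so beta == 0 must bypass the read of C entirely. A scalar sketch of the retained column-stored store (not the DCOL_* macros):

    typedef long dim_t;
    typedef long inc_t;

    static void store_col_tile( dim_t m, dim_t n, double alpha, double beta,
                                const double* ab, dim_t ld_ab,
                                double* c, inc_t cs_c )
    {
        for ( dim_t j = 0; j < n; j++ )
            for ( dim_t i = 0; i < m; i++ )
            {
                double t = alpha * ab[ i + j*ld_ab ];
                if ( beta == 0.0 ) c[ i + j*cs_c ] = t;   /* DBETAZERO path */
                else               c[ i + j*cs_c ] = t + beta * c[ i + j*cs_c ];
            }
    }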

File diff suppressed because it is too large


@@ -32,6 +32,7 @@
*/ */
#include <emmintrin.h>
#include <immintrin.h> #include <immintrin.h>
#include "blis.h" #include "blis.h"
@@ -39,7 +40,9 @@
#if 0 #if 0
void bli_sgemm_sandybridge_int_8x8 void bli_sgemm_sandybridge_int_8x8
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
float* restrict alpha, float* restrict alpha,
float* restrict a, float* restrict a,
float* restrict b, float* restrict b,
@@ -52,11 +55,11 @@ void bli_sgemm_sandybridge_int_8x8
} }
#endif #endif
void bli_dgemm_sandybridge_int_8x4 void bli_dgemm_sandybridge_int_8x4
( (
dim_t k0, dim_t m,
dim_t n,
dim_t k,
double* restrict alpha, double* restrict alpha,
double* restrict a, double* restrict a,
double* restrict b, double* restrict b,
@@ -66,19 +69,22 @@ void bli_dgemm_sandybridge_int_8x4
cntx_t* restrict cntx cntx_t* restrict cntx
) )
{ {
//void* a_next = bli_auxinfo_next_a( data ); //void* a_next = bli_auxinfo_next_a( data );
void* b_next = bli_auxinfo_next_b( data ); void* b_next = bli_auxinfo_next_b( data );
// Typecast local copies of integers in case dim_t and inc_t are a // Typecast local copies of integers in case dim_t and inc_t are a
// different size than is expected by load instructions. // different size than is expected by load instructions.
uint64_t k_iter = k0 / 2; uint64_t k_iter = k / 2;
uint64_t k_left = k0 % 2; uint64_t k_left = k % 2;
uint64_t rs_c = rs_c0; uint64_t rs_c = rs_c0;
uint64_t cs_c = cs_c0; uint64_t cs_c = cs_c0;
uint64_t i; uint64_t i;
double *c00, *c01, *c02, *c03; GEMM_UKR_SETUP_CT( d, 8, 4, false );
double *c40, *c41, *c42, *c43;
double *c00, *c01, *c02, *c03;
double *c40, *c41, *c42, *c43;
// Quad registers. // Quad registers.
__m256d va0_3, va4_7; __m256d va0_3, va4_7;
@@ -97,13 +103,10 @@ void bli_dgemm_sandybridge_int_8x4
__m256d va0_3b2, va4_7b2; __m256d va0_3b2, va4_7b2;
__m256d va0_3b3, va4_7b3; __m256d va0_3b3, va4_7b3;
__m256d valpha, vbeta, vtmp; __m256d valpha, vbeta, vtmp;
__m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3; __m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3;
__m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3; __m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3;
__m128d aa, bb;
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) ); __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) );
__asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) ); __asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) );
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) ); __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) );
@@ -129,19 +132,19 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_3 = _mm256_setzero_pd(); va4_7b_3 = _mm256_setzero_pd();
// Load va0_3 // Load va0_3
va0_3 = _mm256_load_pd( a ); va0_3 = _mm256_load_pd( a );
// Load va4_7 // Load va4_7
va4_7 = _mm256_load_pd( a + 4 ); va4_7 = _mm256_load_pd( a + 4 );
// Load vb (b0,b1,b2,b3) // Load vb (b0,b1,b2,b3)
vb0 = _mm256_load_pd( b ); vb0 = _mm256_load_pd( b );
for( i = 0; i < k_iter; ++i ) for( i = 0; i < k_iter; ++i )
{ {
__asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) ); __asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) );
// Load va0_3 (Prefetch) // Load va0_3 (Prefetch)
vA0_3 = _mm256_load_pd( a + 8 ); vA0_3 = _mm256_load_pd( a + 8 );
// Iteration 0. // Iteration 0.
vtmp = _mm256_mul_pd( va0_3, vb0 ); vtmp = _mm256_mul_pd( va0_3, vb0 );
@@ -151,10 +154,10 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
// Load va4_7 (Prefetch) // Load va4_7 (Prefetch)
vA4_7 = _mm256_load_pd( a + 12 ); vA4_7 = _mm256_load_pd( a + 12 );
// Shuffle vb (b1,b0,b3,b2) // Shuffle vb (b1,b0,b3,b2)
vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 ); vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 );
vtmp = _mm256_mul_pd( va0_3, vb1 ); vtmp = _mm256_mul_pd( va0_3, vb1 );
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
@@ -163,10 +166,10 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
// Permute vb (b3,b2,b1,b0) // Permute vb (b3,b2,b1,b0)
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
// Load vb (b0,b1,b2,b3) (Prefetch) // Load vb (b0,b1,b2,b3) (Prefetch)
vB0 = _mm256_load_pd( b + 4 ); vB0 = _mm256_load_pd( b + 4 );
vtmp = _mm256_mul_pd( va0_3, vb2 ); vtmp = _mm256_mul_pd( va0_3, vb2 );
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
@@ -175,7 +178,7 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
// Shuffle vb (b3,b2,b1,b0) // Shuffle vb (b3,b2,b1,b0)
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
vtmp = _mm256_mul_pd( va0_3, vb3 ); vtmp = _mm256_mul_pd( va0_3, vb3 );
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
@@ -188,12 +191,12 @@ void bli_dgemm_sandybridge_int_8x4
__asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) ); __asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) );
// Load va0_3 (Next iteration) // Load va0_3 (Next iteration)
va0_3 = _mm256_load_pd( a + 16 ); va0_3 = _mm256_load_pd( a + 16 );
vtmp = _mm256_mul_pd( vA0_3, vB0 ); vtmp = _mm256_mul_pd( vA0_3, vB0 );
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp ); va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 ); vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 );
vtmp = _mm256_mul_pd( vA4_7, vB0 ); vtmp = _mm256_mul_pd( vA4_7, vB0 );
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
@@ -202,9 +205,9 @@ void bli_dgemm_sandybridge_int_8x4
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
// Load va4_7 (Next iteration) // Load va4_7 (Next iteration)
va4_7 = _mm256_load_pd( a + 20 ); va4_7 = _mm256_load_pd( a + 20 );
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
vtmp = _mm256_mul_pd( vA4_7, vb1 ); vtmp = _mm256_mul_pd( vA4_7, vb1 );
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
@@ -212,13 +215,13 @@ void bli_dgemm_sandybridge_int_8x4
vtmp = _mm256_mul_pd( vA0_3, vb2 ); vtmp = _mm256_mul_pd( vA0_3, vb2 );
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
vtmp = _mm256_mul_pd( vA4_7, vb2 ); vtmp = _mm256_mul_pd( vA4_7, vb2 );
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
// Load vb0(Next iteration) // Load vb0(Next iteration)
vb0 = _mm256_load_pd( b + 8 ); vb0 = _mm256_load_pd( b + 8 );
vtmp = _mm256_mul_pd( vA0_3, vb3 ); vtmp = _mm256_mul_pd( vA0_3, vb3 );
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
@@ -236,12 +239,12 @@ void bli_dgemm_sandybridge_int_8x4
// Iteration 0. // Iteration 0.
// Load va0_3 // Load va0_3
va0_3 = _mm256_load_pd( a ); va0_3 = _mm256_load_pd( a );
// Load va4_7 // Load va4_7
va4_7 = _mm256_load_pd( a + 4 ); va4_7 = _mm256_load_pd( a + 4 );
// Load vb (b0,b1,b2,b3) // Load vb (b0,b1,b2,b3)
vb = _mm256_load_pd( b ); vb = _mm256_load_pd( b );
vtmp = _mm256_mul_pd( va0_3, vb ); vtmp = _mm256_mul_pd( va0_3, vb );
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp ); va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
@@ -250,7 +253,7 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
// Shuffle vb (b1,b0,b3,b2) // Shuffle vb (b1,b0,b3,b2)
vb = _mm256_shuffle_pd( vb, vb, 0x5 ); vb = _mm256_shuffle_pd( vb, vb, 0x5 );
vtmp = _mm256_mul_pd( va0_3, vb ); vtmp = _mm256_mul_pd( va0_3, vb );
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
@@ -259,7 +262,7 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
// Permute vb (b3,b2,b1,b0) // Permute vb (b3,b2,b1,b0)
vb = _mm256_permute2f128_pd( vb, vb, 0x1 ); vb = _mm256_permute2f128_pd( vb, vb, 0x1 );
vtmp = _mm256_mul_pd( va0_3, vb ); vtmp = _mm256_mul_pd( va0_3, vb );
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
@@ -268,7 +271,7 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
// Shuffle vb (b3,b2,b1,b0) // Shuffle vb (b3,b2,b1,b0)
vb = _mm256_shuffle_pd( vb, vb, 0x5 ); vb = _mm256_shuffle_pd( vb, vb, 0x5 );
vtmp = _mm256_mul_pd( va0_3, vb ); vtmp = _mm256_mul_pd( va0_3, vb );
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
@@ -309,12 +312,72 @@ void bli_dgemm_sandybridge_int_8x4
va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 ); va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 );
va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 ); va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 );
if( rs_c == 1 ) __m128d vzero = _mm_setzero_pd( );
if( _mm_comieq_sd( _mm256_castpd256_pd128(vbeta), vzero ) )
{ {
// Calculate address // Calculate address
c00 = ( c + 0*rs_c + 0*cs_c ); c00 = ( c + 0 + 0*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b0);
// Store back to memory
_mm256_store_pd( c00, vtmp );
// Calculate address
c40 = ( c + 4 + 0*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b0);
// Store back to memory
_mm256_store_pd( c40, vtmp );
// Calculate address
c01 = ( c + 0 + 1*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b1);
// Store back to memory
_mm256_store_pd( c01, vtmp );
// Calculate address
c41 = ( c + 4 + 1*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b1);
// Store back to memory
_mm256_store_pd( c41, vtmp );
// Calculate address
c02 = ( c + 0 + 2*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b2);
// Store back to memory
_mm256_store_pd( c02, vtmp );
// Calculate address
c42 = ( c + 4 + 2*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b2);
// Store back to memory
_mm256_store_pd( c42, vtmp );
// Calculate address
c03 = ( c + 0 + 3*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b3);
// Store back to memory
_mm256_store_pd( c03, vtmp );
// Calculate address
c43 = ( c + 4 + 3*cs_c );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b3);
// Store back to memory
_mm256_store_pd( c43, vtmp );
}
else
{
// Calculate address
c00 = ( c + 0 + 0*cs_c );
// Load // Load
//vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); //vc0_3_0 = _mm256_load_pd( c + 0 + 0*cs_c );
vc0_3_0 = _mm256_load_pd( c00 ); vc0_3_0 = _mm256_load_pd( c00 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b0); vtmp = _mm256_mul_pd( valpha, va0_3b0);
@@ -326,9 +389,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c00, vc0_3_0 ); _mm256_store_pd( c00, vc0_3_0 );
// Calculate address // Calculate address
c40 = ( c + 4*rs_c + 0*cs_c ); c40 = ( c + 4 + 0*cs_c );
// Load // Load
//vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); //vc4_7_0 = _mm256_load_pd( c + 4 + 0*cs_c );
vc4_7_0 = _mm256_load_pd( c40 ); vc4_7_0 = _mm256_load_pd( c40 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b0); vtmp = _mm256_mul_pd( valpha, va4_7b0);
@@ -340,9 +403,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c40, vc4_7_0 ); _mm256_store_pd( c40, vc4_7_0 );
// Calculate address // Calculate address
c01 = ( c + 0*rs_c + 1*cs_c ); c01 = ( c + 0 + 1*cs_c );
// Load // Load
//vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); //vc0_3_1 = _mm256_load_pd( c + 0 + 1*cs_c );
vc0_3_1 = _mm256_load_pd( c01 ); vc0_3_1 = _mm256_load_pd( c01 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b1); vtmp = _mm256_mul_pd( valpha, va0_3b1);
@@ -354,9 +417,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c01, vc0_3_1 ); _mm256_store_pd( c01, vc0_3_1 );
// Calculate address // Calculate address
c41 = ( c + 4*rs_c + 1*cs_c ); c41 = ( c + 4 + 1*cs_c );
// Load // Load
//vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); //vc4_7_1 = _mm256_load_pd( c + 4 + 1*cs_c );
vc4_7_1 = _mm256_load_pd( c41 ); vc4_7_1 = _mm256_load_pd( c41 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b1); vtmp = _mm256_mul_pd( valpha, va4_7b1);
@@ -368,9 +431,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c41, vc4_7_1 ); _mm256_store_pd( c41, vc4_7_1 );
// Calculate address // Calculate address
c02 = ( c + 0*rs_c + 2*cs_c ); c02 = ( c + 0 + 2*cs_c );
// Load // Load
//vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); //vc0_3_2 = _mm256_load_pd( c + 0 + 2*cs_c );
vc0_3_2 = _mm256_load_pd( c02 ); vc0_3_2 = _mm256_load_pd( c02 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b2); vtmp = _mm256_mul_pd( valpha, va0_3b2);
@@ -382,9 +445,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c02, vc0_3_2 ); _mm256_store_pd( c02, vc0_3_2 );
// Calculate address // Calculate address
c42 = ( c + 4*rs_c + 2*cs_c ); c42 = ( c + 4 + 2*cs_c );
// Load // Load
//vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); //vc4_7_2 = _mm256_load_pd( c + 4 + 2*cs_c );
vc4_7_2 = _mm256_load_pd( c42 ); vc4_7_2 = _mm256_load_pd( c42 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b2); vtmp = _mm256_mul_pd( valpha, va4_7b2);
@@ -396,9 +459,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c42, vc4_7_2 ); _mm256_store_pd( c42, vc4_7_2 );
// Calculate address // Calculate address
c03 = ( c + 0*rs_c + 3*cs_c ); c03 = ( c + 0 + 3*cs_c );
// Load // Load
//vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); //vc0_3_3 = _mm256_load_pd( c + 0 + 3*cs_c );
vc0_3_3 = _mm256_load_pd( c03 ); vc0_3_3 = _mm256_load_pd( c03 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b3); vtmp = _mm256_mul_pd( valpha, va0_3b3);
@@ -410,9 +473,9 @@ void bli_dgemm_sandybridge_int_8x4
_mm256_store_pd( c03, vc0_3_3 ); _mm256_store_pd( c03, vc0_3_3 );
// Calculate address // Calculate address
c43 = ( c + 4*rs_c + 3*cs_c ); c43 = ( c + 4 + 3*cs_c );
// Load // Load
//vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); //vc4_7_3 = _mm256_load_pd( c + 4 + 3*cs_c );
vc4_7_3 = _mm256_load_pd( c43 ); vc4_7_3 = _mm256_load_pd( c43 );
// Scale by alpha // Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b3); vtmp = _mm256_mul_pd( valpha, va4_7b3);
@@ -422,211 +485,9 @@ void bli_dgemm_sandybridge_int_8x4
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
// Store back to memory // Store back to memory
_mm256_store_pd( c43, vc4_7_3 ); _mm256_store_pd( c43, vc4_7_3 );
}
else
{
// Calculate address
c00 = ( c + 0*rs_c + 0*cs_c );
// Load
//vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c );
vc0_3_0 = _mm256_set_pd( *(c + 3*rs_c + 0*cs_c ),
*(c + 2*rs_c + 0*cs_c ),
*(c + 1*rs_c + 0*cs_c ),
*(c + 0*rs_c + 0*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b0);
// Scale by beta
vc0_3_0 = _mm256_mul_pd( vbeta, vc0_3_0 );
// Add gemm result
vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp );
// Store back to memory
//_mm256_store_pd( c00, vc0_3_0 );
aa = _mm256_extractf128_pd( vc0_3_0, 0 ) ;
bb = _mm256_extractf128_pd( vc0_3_0, 1 ) ;
_mm_storel_pd( c + 0*rs_c + 0*cs_c, aa );
_mm_storeh_pd( c + 1*rs_c + 0*cs_c, aa );
_mm_storel_pd( c + 2*rs_c + 0*cs_c, bb );
_mm_storeh_pd( c + 3*rs_c + 0*cs_c, bb );
// Calculate address
c40 = ( c + 4*rs_c + 0*cs_c );
// Load
//vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c );
vc4_7_0 = _mm256_set_pd( *(c + 7*rs_c + 0*cs_c ),
*(c + 6*rs_c + 0*cs_c ),
*(c + 5*rs_c + 0*cs_c ),
*(c + 4*rs_c + 0*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b0);
// Scale by beta
vc4_7_0 = _mm256_mul_pd( vbeta, vc4_7_0 );
// Add gemm result
vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp );
// Store back to memory
//_mm256_store_pd( c40, vc4_7_0 );
aa = _mm256_extractf128_pd( vc4_7_0, 0 ) ;
bb = _mm256_extractf128_pd( vc4_7_0, 1 ) ;
_mm_storel_pd( c + 4*rs_c + 0*cs_c, aa );
_mm_storeh_pd( c + 5*rs_c + 0*cs_c, aa );
_mm_storel_pd( c + 6*rs_c + 0*cs_c, bb );
_mm_storeh_pd( c + 7*rs_c + 0*cs_c, bb );
// Calculate address
c01 = ( c + 0*rs_c + 1*cs_c );
// Load
//vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c );
vc0_3_1 = _mm256_set_pd( *(c + 3*rs_c + 1*cs_c ),
*(c + 2*rs_c + 1*cs_c ),
*(c + 1*rs_c + 1*cs_c ),
*(c + 0*rs_c + 1*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b1);
// Scale by beta
vc0_3_1 = _mm256_mul_pd( vbeta, vc0_3_1 );
// Add gemm result
vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp );
// Store back to memory
//_mm256_store_pd( c01, vc0_3_1 );
aa = _mm256_extractf128_pd( vc0_3_1, 0 ) ;
bb = _mm256_extractf128_pd( vc0_3_1, 1 ) ;
_mm_storel_pd( c + 0*rs_c + 1*cs_c, aa );
_mm_storeh_pd( c + 1*rs_c + 1*cs_c, aa );
_mm_storel_pd( c + 2*rs_c + 1*cs_c, bb );
_mm_storeh_pd( c + 3*rs_c + 1*cs_c, bb );
// Calculate address
c41 = ( c + 4*rs_c + 1*cs_c );
// Load
//vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c );
vc4_7_1 = _mm256_set_pd( *(c + 7*rs_c + 1*cs_c ),
*(c + 6*rs_c + 1*cs_c ),
*(c + 5*rs_c + 1*cs_c ),
*(c + 4*rs_c + 1*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b1);
// Scale by beta
vc4_7_1 = _mm256_mul_pd( vbeta, vc4_7_1 );
// Add gemm result
vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp );
// Store back to memory
//_mm256_store_pd( c41, vc4_7_1 );
aa = _mm256_extractf128_pd( vc4_7_1, 0 ) ;
bb = _mm256_extractf128_pd( vc4_7_1, 1 ) ;
_mm_storel_pd( c + 4*rs_c + 1*cs_c, aa );
_mm_storeh_pd( c + 5*rs_c + 1*cs_c, aa );
_mm_storel_pd( c + 6*rs_c + 1*cs_c, bb );
_mm_storeh_pd( c + 7*rs_c + 1*cs_c, bb );
// Calculate address
c02 = ( c + 0*rs_c + 2*cs_c );
// Load
//vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c );
vc0_3_2 = _mm256_set_pd( *(c + 3*rs_c + 2*cs_c ),
*(c + 2*rs_c + 2*cs_c ),
*(c + 1*rs_c + 2*cs_c ),
*(c + 0*rs_c + 2*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b2);
// Scale by beta
vc0_3_2 = _mm256_mul_pd( vbeta, vc0_3_2 );
// Add gemm result
vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp );
// Store back to memory
//_mm256_store_pd( c02, vc0_3_2 );
aa = _mm256_extractf128_pd( vc0_3_2, 0 ) ;
bb = _mm256_extractf128_pd( vc0_3_2, 1 ) ;
_mm_storel_pd( c + 0*rs_c + 2*cs_c, aa );
_mm_storeh_pd( c + 1*rs_c + 2*cs_c, aa );
_mm_storel_pd( c + 2*rs_c + 2*cs_c, bb );
_mm_storeh_pd( c + 3*rs_c + 2*cs_c, bb );
// Calculate address
c42 = ( c + 4*rs_c + 2*cs_c );
// Load
//vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c );
vc4_7_2 = _mm256_set_pd( *(c + 7*rs_c + 2*cs_c ),
*(c + 6*rs_c + 2*cs_c ),
*(c + 5*rs_c + 2*cs_c ),
*(c + 4*rs_c + 2*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b2);
// Scale by beta
vc4_7_2 = _mm256_mul_pd( vbeta, vc4_7_2 );
// Add gemm result
vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp );
// Store back to memory
//_mm256_store_pd( c42, vc4_7_2 );
aa = _mm256_extractf128_pd( vc4_7_2, 0 ) ;
bb = _mm256_extractf128_pd( vc4_7_2, 1 ) ;
_mm_storel_pd( c + 4*rs_c + 2*cs_c, aa );
_mm_storeh_pd( c + 5*rs_c + 2*cs_c, aa );
_mm_storel_pd( c + 6*rs_c + 2*cs_c, bb );
_mm_storeh_pd( c + 7*rs_c + 2*cs_c, bb );
// Calculate address
c03 = ( c + 0*rs_c + 3*cs_c );
// Load
//vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c );
vc0_3_3 = _mm256_set_pd( *(c + 3*rs_c + 3*cs_c ),
*(c + 2*rs_c + 3*cs_c ),
*(c + 1*rs_c + 3*cs_c ),
*(c + 0*rs_c + 3*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va0_3b3);
// Scale by beta
vc0_3_3 = _mm256_mul_pd( vbeta, vc0_3_3 );
// Add gemm result
vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp );
// Store back to memory
//_mm256_store_pd( c03, vc0_3_3 );
aa = _mm256_extractf128_pd( vc0_3_3, 0 ) ;
bb = _mm256_extractf128_pd( vc0_3_3, 1 ) ;
_mm_storel_pd( c + 0*rs_c + 3*cs_c, aa );
_mm_storeh_pd( c + 1*rs_c + 3*cs_c, aa );
_mm_storel_pd( c + 2*rs_c + 3*cs_c, bb );
_mm_storeh_pd( c + 3*rs_c + 3*cs_c, bb );
// Calculate address
c43 = ( c + 4*rs_c + 3*cs_c );
// Load
//vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c );
vc4_7_3 = _mm256_set_pd( *(c + 7*rs_c + 3*cs_c ),
*(c + 6*rs_c + 3*cs_c ),
*(c + 5*rs_c + 3*cs_c ),
*(c + 4*rs_c + 3*cs_c ) );
// Scale by alpha
vtmp = _mm256_mul_pd( valpha, va4_7b3);
// Scale by beta
vc4_7_3 = _mm256_mul_pd( vbeta, vc4_7_3 );
// Add gemm result
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
// Store back to memory
//_mm256_store_pd( c43, vc4_7_3 );
aa = _mm256_extractf128_pd( vc4_7_3, 0 ) ;
bb = _mm256_extractf128_pd( vc4_7_3, 1 ) ;
_mm_storel_pd( c + 4*rs_c + 3*cs_c, aa );
_mm_storeh_pd( c + 5*rs_c + 3*cs_c, aa );
_mm_storel_pd( c + 6*rs_c + 3*cs_c, bb );
_mm_storeh_pd( c + 7*rs_c + 3*cs_c, bb );
     }
+
+    GEMM_UKR_FLUSH_CT( d );
 }
@@ -634,7 +495,9 @@ void bli_dgemm_sandybridge_int_8x4
 #if 0
 void bli_cgemm_sandybridge_int_8x4
      (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        scomplex*  restrict alpha,
        scomplex*  restrict a,
        scomplex*  restrict b,
@@ -652,7 +515,9 @@ void bli_cgemm_sandybridge_int_8x4
 #if 0
 void bli_zgemm_sandybridge_int_4x4
      (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        dcomplex*  restrict alpha,
        dcomplex*  restrict a,
        dcomplex*  restrict b,

View File

@@ -287,24 +287,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) =
     { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};

-void bli_dgemm_skx_asm_16x12_l2(
-       dim_t            k_,
-       double* restrict alpha,
-       double* restrict a,
-       double* restrict b,
-       double* restrict beta,
-       double* restrict c, inc_t rs_c_, inc_t cs_c_,
-       auxinfo_t*       data,
-       cntx_t* restrict cntx
-     )
+void bli_dgemm_skx_asm_16x12_l2
+     (
+       dim_t            m,
+       dim_t            n,
+       dim_t            k_,
+       double* restrict alpha,
+       double* restrict a,
+       double* restrict b,
+       double* restrict beta,
+       double* restrict c, inc_t rs_c_, inc_t cs_c_,
+       auxinfo_t*       data,
+       cntx_t* restrict cntx
+     )
 {
     (void)data;
     (void)cntx;

-    const int64_t* offsetPtr = &offsets[0];
-    const int64_t k = k_;
-    const int64_t rs_c = rs_c_;
-    const int64_t cs_c = cs_c_;
+    int64_t k = k_;
+    int64_t rs_c = rs_c_;
+    int64_t cs_c = cs_c_;
+
+    GEMM_UKR_SETUP_CT( d, 16, 12, false );

     BEGIN_ASM()
@@ -464,62 +468,26 @@ void bli_dgemm_skx_asm_16x12_l2(
     MOV(RAX, VAR(cs_c))
     LEA(RAX, MEM(,RAX,8))
-    MOV(RBX, VAR(rs_c))
-    LEA(RBX, MEM(,RBX,8))
-
-    // Check if C is column stride. If not, jump to the slow scattered update
-    CMP(RBX, IMM(1))
-    JNE(SCATTEREDUPDATE)
-
     VCOMISD(XMM(1), XMM(7))
     JE(COLSTORBZ)

     UPDATE_C( 8, 9,10,11)
     UPDATE_C(12,13,14,15)
     UPDATE_C(16,17,18,19)
     UPDATE_C(20,21,22,23)
     UPDATE_C(24,25,26,27)
     UPDATE_C(28,29,30,31)
     JMP(END)

     LABEL(COLSTORBZ)

     UPDATE_C_BZ( 8, 9,10,11)
     UPDATE_C_BZ(12,13,14,15)
     UPDATE_C_BZ(16,17,18,19)
     UPDATE_C_BZ(20,21,22,23)
     UPDATE_C_BZ(24,25,26,27)
     UPDATE_C_BZ(28,29,30,31)
-    JMP(END)
-
-    LABEL(SCATTEREDUPDATE)
-
-    MOV(RDI, VAR(offsetPtr))
-    VMOVDQA64(ZMM(2), MEM(RDI,0*64))
-    VMOVDQA64(ZMM(3), MEM(RDI,1*64))
-    VPBROADCASTQ(ZMM(6), RBX)
-    VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
-    VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
-
-    VCOMISD(XMM(1), XMM(7))
-    JE(SCATTERBZ)
-
-    UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
-    UPDATE_C_ROW_SCATTERED(12,13,14,15)
-    UPDATE_C_ROW_SCATTERED(16,17,18,19)
-    UPDATE_C_ROW_SCATTERED(20,21,22,23)
-    UPDATE_C_ROW_SCATTERED(24,25,26,27)
-    UPDATE_C_ROW_SCATTERED(28,29,30,31)
-    JMP(END)
-
-    LABEL(SCATTERBZ)
-
-    UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
-    UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
-    UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
-    UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
-    UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
-    UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)

     LABEL(END)
@@ -535,8 +503,7 @@ void bli_dgemm_skx_asm_16x12_l2(
       [beta] "m" (beta),
       [c]    "m" (c),
       [rs_c] "m" (rs_c),
-      [cs_c] "m" (cs_c),
-      [offsetPtr] "m" (offsetPtr)
+      [cs_c] "m" (cs_c)
     : // register clobber list
       "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
       "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
@@ -545,4 +512,6 @@ void bli_dgemm_skx_asm_16x12_l2(
       "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
       "zmm30", "zmm31", "memory"
     )
+
+    GEMM_UKR_FLUSH_CT( d );
 }
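The same pattern repeats in each converted kernel: the kernel body always computes a full MR x NR tile, and the GEMM_UKR_SETUP_CT / GEMM_UKR_FLUSH_CT bracket redirects that tile into a local buffer whenever the requested m x n is partial or the strides of C are not directly supported, then copies the valid piece back. A self-contained toy model of that mechanism, in plain C; the names and the simplified trigger condition are illustrative, not the actual macro definitions:

#include <stdio.h>

enum { MR = 4, NR = 4 };

/* Toy stand-in for the setup/flush pair: compute into a full-size local
   microtile when the output tile is partial, then copy only the valid
   m x n piece back out to C. */
static void toy_ukr( int m, int n, const double* ab,
                     double* c, int rs_c, int cs_c )
{
    double  ct[ MR * NR ];
    int     use_ct = ( m != MR || n != NR );  /* "setup" decision */
    double* c_out  = use_ct ? ct : c;
    int     rs_out = use_ct ? 1  : rs_c;
    int     cs_out = use_ct ? MR : cs_c;

    /* The kernel proper always writes the full MR x NR tile. */
    for ( int j = 0; j < NR; j++ )
        for ( int i = 0; i < MR; i++ )
            c_out[ i*rs_out + j*cs_out ] = ab[ i + j*MR ];

    /* "Flush": copy only the valid part back to C. */
    if ( use_ct )
        for ( int j = 0; j < n; j++ )
            for ( int i = 0; i < m; i++ )
                c[ i*rs_c + j*cs_c ] = ct[ i + j*MR ];
}

int main( void )
{
    double ab[ MR * NR ], c[ 6 ] = { 0 };
    for ( int i = 0; i < MR * NR; i++ ) ab[ i ] = i;
    toy_ukr( 3, 2, ab, c, 1, 3 );         /* 3x2 edge tile, column-stored */
    printf( "%g %g\n", c[ 2 ], c[ 5 ] );  /* prints "2 6" */
    return 0;
}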

View File

@@ -153,24 +153,28 @@
 static int64_t offsets[16] __attribute__((aligned(64))) =
     { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};

-void bli_dgemm_skx_asm_16x14(
-       dim_t            k_,
-       double* restrict alpha,
-       double* restrict a,
-       double* restrict b,
-       double* restrict beta,
-       double* restrict c, inc_t rs_c_, inc_t cs_c_,
-       auxinfo_t*       data,
-       cntx_t* restrict cntx
-     )
+void bli_dgemm_skx_asm_16x14
+     (
+       dim_t            m,
+       dim_t            n,
+       dim_t            k_,
+       double* restrict alpha,
+       double* restrict a,
+       double* restrict b,
+       double* restrict beta,
+       double* restrict c, inc_t rs_c_, inc_t cs_c_,
+       auxinfo_t*       data,
+       cntx_t* restrict cntx
+     )
 {
     (void)data;
     (void)cntx;

-    const int64_t* offsetPtr = &offsets[0];
-    const int64_t k = k_;
-    const int64_t rs_c = rs_c_*8;
-    const int64_t cs_c = cs_c_*8;
+    int64_t k = k_;
+    int64_t rs_c = rs_c_;
+    int64_t cs_c = cs_c_;
+
+    GEMM_UKR_SETUP_CT( d, 16, 14, false );

     BEGIN_ASM()
@@ -220,6 +224,8 @@ void bli_dgemm_skx_asm_16x14(
     MOV(R12, VAR(rs_c))
     MOV(R10, VAR(cs_c))
+    LEA(R12, MEM(,R12,8))
+    LEA(R10, MEM(,R10,8))

     MOV(RDI, RSI)
     AND(RSI, IMM(3))
@@ -320,119 +326,41 @@ void bli_dgemm_skx_asm_16x14(
     MOV(RAX, R12)
     MOV(RBX, R10)
-
-    // Check if C is column stride.
-    CMP(RAX, IMM(8))
-    JNE(SCATTEREDUPDATE)
-
     VCOMISD(XMM(1), XMM(2))
     JE(COLSTORBZ)

     UPDATE_C( 4, 5)
     UPDATE_C( 6, 7)
     UPDATE_C( 8, 9)
     UPDATE_C(10,11)
     UPDATE_C(12,13)
     UPDATE_C(14,15)
     UPDATE_C(16,17)
     UPDATE_C(18,19)
     UPDATE_C(20,21)
     UPDATE_C(22,23)
     UPDATE_C(24,25)
     UPDATE_C(26,27)
     UPDATE_C(28,29)
     UPDATE_C(30,31)
     JMP(END)

     LABEL(COLSTORBZ)

     UPDATE_C_BZ( 4, 5)
     UPDATE_C_BZ( 6, 7)
     UPDATE_C_BZ( 8, 9)
     UPDATE_C_BZ(10,11)
     UPDATE_C_BZ(12,13)
     UPDATE_C_BZ(14,15)
     UPDATE_C_BZ(16,17)
     UPDATE_C_BZ(18,19)
     UPDATE_C_BZ(20,21)
     UPDATE_C_BZ(22,23)
     UPDATE_C_BZ(24,25)
     UPDATE_C_BZ(26,27)
     UPDATE_C_BZ(28,29)
     UPDATE_C_BZ(30,31)
-    JMP(END)
-
-    LABEL(SCATTEREDUPDATE)
-
-    VMULPD(ZMM( 4), ZMM( 4), ZMM(0))
-    VMULPD(ZMM( 5), ZMM( 5), ZMM(0))
-    VMULPD(ZMM( 6), ZMM( 6), ZMM(0))
-    VMULPD(ZMM( 7), ZMM( 7), ZMM(0))
-    VMULPD(ZMM( 8), ZMM( 8), ZMM(0))
-    VMULPD(ZMM( 9), ZMM( 9), ZMM(0))
-    VMULPD(ZMM(10), ZMM(10), ZMM(0))
-    VMULPD(ZMM(11), ZMM(11), ZMM(0))
-    VMULPD(ZMM(12), ZMM(12), ZMM(0))
-    VMULPD(ZMM(13), ZMM(13), ZMM(0))
-    VMULPD(ZMM(14), ZMM(14), ZMM(0))
-    VMULPD(ZMM(15), ZMM(15), ZMM(0))
-    VMULPD(ZMM(16), ZMM(16), ZMM(0))
-    VMULPD(ZMM(17), ZMM(17), ZMM(0))
-    VMULPD(ZMM(18), ZMM(18), ZMM(0))
-    VMULPD(ZMM(19), ZMM(19), ZMM(0))
-    VMULPD(ZMM(20), ZMM(20), ZMM(0))
-    VMULPD(ZMM(21), ZMM(21), ZMM(0))
-    VMULPD(ZMM(22), ZMM(22), ZMM(0))
-    VMULPD(ZMM(23), ZMM(23), ZMM(0))
-    VMULPD(ZMM(24), ZMM(24), ZMM(0))
-    VMULPD(ZMM(25), ZMM(25), ZMM(0))
-    VMULPD(ZMM(26), ZMM(26), ZMM(0))
-    VMULPD(ZMM(27), ZMM(27), ZMM(0))
-    VMULPD(ZMM(28), ZMM(28), ZMM(0))
-    VMULPD(ZMM(29), ZMM(29), ZMM(0))
-    VMULPD(ZMM(30), ZMM(30), ZMM(0))
-    VMULPD(ZMM(31), ZMM(31), ZMM(0))
-
-    VCOMISD(XMM(1), XMM(2))
-
-    MOV(RDI, VAR(offsetPtr))
-    VPBROADCASTQ(ZMM(0), RAX)
-    VPMULLQ(ZMM(2), ZMM(0), MEM(RDI))
-    VPMULLQ(ZMM(3), ZMM(0), MEM(RDI,64))
-
-    JE(SCATTERBZ)
-
-    UPDATE_C_COL_SCATTERED( 4, 5)
-    UPDATE_C_COL_SCATTERED( 6, 7)
-    UPDATE_C_COL_SCATTERED( 8, 9)
-    UPDATE_C_COL_SCATTERED(10,11)
-    UPDATE_C_COL_SCATTERED(12,13)
-    UPDATE_C_COL_SCATTERED(14,15)
-    UPDATE_C_COL_SCATTERED(16,17)
-    UPDATE_C_COL_SCATTERED(18,19)
-    UPDATE_C_COL_SCATTERED(20,21)
-    UPDATE_C_COL_SCATTERED(22,23)
-    UPDATE_C_COL_SCATTERED(24,25)
-    UPDATE_C_COL_SCATTERED(26,27)
-    UPDATE_C_COL_SCATTERED(28,29)
-    UPDATE_C_COL_SCATTERED(30,31)
-    JMP(END)
-
-    LABEL(SCATTERBZ)
-
-    UPDATE_C_BZ_COL_SCATTERED( 4, 5)
-    UPDATE_C_BZ_COL_SCATTERED( 6, 7)
-    UPDATE_C_BZ_COL_SCATTERED( 8, 9)
-    UPDATE_C_BZ_COL_SCATTERED(10,11)
-    UPDATE_C_BZ_COL_SCATTERED(12,13)
-    UPDATE_C_BZ_COL_SCATTERED(14,15)
-    UPDATE_C_BZ_COL_SCATTERED(16,17)
-    UPDATE_C_BZ_COL_SCATTERED(18,19)
-    UPDATE_C_BZ_COL_SCATTERED(20,21)
-    UPDATE_C_BZ_COL_SCATTERED(22,23)
-    UPDATE_C_BZ_COL_SCATTERED(24,25)
-    UPDATE_C_BZ_COL_SCATTERED(26,27)
-    UPDATE_C_BZ_COL_SCATTERED(28,29)
-    UPDATE_C_BZ_COL_SCATTERED(30,31)

     LABEL(END)
@@ -449,8 +377,7 @@ void bli_dgemm_skx_asm_16x14(
       [beta] "m" (beta),
       [c]    "m" (c),
       [rs_c] "m" (rs_c),
-      [cs_c] "m" (cs_c),
-      [offsetPtr] "m" (offsetPtr)
+      [cs_c] "m" (cs_c)
     : // register clobber list
       "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
       "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
@@ -459,4 +386,6 @@ void bli_dgemm_skx_asm_16x14(
       "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
       "zmm30", "zmm31", "memory"
     )
+
+    GEMM_UKR_FLUSH_CT( d );
 }

View File

@@ -317,24 +317,28 @@ ahead*/
 static int64_t offsets[16] __attribute__((aligned(64))) =
     { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};

-void bli_sgemm_skx_asm_32x12_l2(
-       dim_t            k_,
-       float*  restrict alpha,
-       float*  restrict a,
-       float*  restrict b,
-       float*  restrict beta,
-       float*  restrict c, inc_t rs_c_, inc_t cs_c_,
-       auxinfo_t*       data,
-       cntx_t* restrict cntx
-     )
+void bli_sgemm_skx_asm_32x12_l2
+     (
+       dim_t            m,
+       dim_t            n,
+       dim_t            k_,
+       float*  restrict alpha,
+       float*  restrict a,
+       float*  restrict b,
+       float*  restrict beta,
+       float*  restrict c, inc_t rs_c_, inc_t cs_c_,
+       auxinfo_t*       data,
+       cntx_t* restrict cntx
+     )
 {
     (void)data;
     (void)cntx;

-    const int64_t* offsetPtr = &offsets[0];
-    const int64_t k = k_;
-    const int64_t rs_c = rs_c_;
-    const int64_t cs_c = cs_c_;
+    int64_t k = k_;
+    int64_t rs_c = rs_c_;
+    int64_t cs_c = cs_c_;
+
+    GEMM_UKR_SETUP_CT( s, 32, 12, false );

     BEGIN_ASM()
@@ -381,7 +385,7 @@ void bli_sgemm_skx_asm_32x12_l2(
 #endif
 #ifdef PREFETCH_B_BEFORE
     /* Prefetching 3 cachlines of B (4 iterations worth of data
       (12 (NR) x 4 (sizeof(float)) x 4 iter /64 = 3 cachelines) */
     PREFETCH(0, MEM(RBX,0*64))
     PREFETCH(0, MEM(RBX,1*64))
@@ -485,66 +489,26 @@ void bli_sgemm_skx_asm_32x12_l2(
     MOV(RAX, VAR(cs_c))
     LEA(RAX, MEM(,RAX,4))
-    MOV(RBX, VAR(rs_c))
-    LEA(RBX, MEM(,RBX,4))
-
-    // Check if C is column major (rs_c = 1). If not, jump to the slow scattered update
-    CMP(RBX, IMM(4))
-    JNE(SCATTEREDUPDATE)
-
     VCOMISS(XMM(1), XMM(7))
     JE(COLSTORBZ)

     UPDATE_C( 8, 9,10,11)
     UPDATE_C(12,13,14,15)
     UPDATE_C(16,17,18,19)
     UPDATE_C(20,21,22,23)
     UPDATE_C(24,25,26,27)
     UPDATE_C(28,29,30,31)
     JMP(END)

     LABEL(COLSTORBZ)

     UPDATE_C_BZ( 8, 9,10,11)
     UPDATE_C_BZ(12,13,14,15)
     UPDATE_C_BZ(16,17,18,19)
     UPDATE_C_BZ(20,21,22,23)
     UPDATE_C_BZ(24,25,26,27)
     UPDATE_C_BZ(28,29,30,31)
-    JMP(END)
-
-    LABEL(SCATTEREDUPDATE)
-
-    LEA(RDX, MEM(RCX,RBX,8))
-    LEA(RDX, MEM(RDX,RBX,8))
-
-    MOV(RDI, VAR(offsetPtr))
-    VMOVDQA64(ZMM(2), MEM(RDI,0*64))
-    VMOVDQA64(ZMM(3), MEM(RDI,1*64))
-    VPBROADCASTQ(ZMM(6), RBX)
-    VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
-    VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
-
-    VCOMISS(XMM(1), XMM(7))
-    JE(SCATTERBZ)
-
-    UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
-    UPDATE_C_ROW_SCATTERED(12,13,14,15)
-    UPDATE_C_ROW_SCATTERED(16,17,18,19)
-    UPDATE_C_ROW_SCATTERED(20,21,22,23)
-    UPDATE_C_ROW_SCATTERED(24,25,26,27)
-    UPDATE_C_ROW_SCATTERED(28,29,30,31)
-    JMP(END)
-
-    LABEL(SCATTERBZ)
-
-    UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
-    UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
-    UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
-    UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
-    UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
-    UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)

     LABEL(END)
@@ -560,8 +524,7 @@ void bli_sgemm_skx_asm_32x12_l2(
       [beta] "m" (beta),
       [c]    "m" (c),
       [rs_c] "m" (rs_c),
-      [cs_c] "m" (cs_c),
-      [offsetPtr] "m" (offsetPtr)
+      [cs_c] "m" (cs_c)
     : // register clobber list
       "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
       "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
@@ -570,4 +533,6 @@ void bli_sgemm_skx_asm_32x12_l2(
       "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
       "zmm30", "zmm31", "memory"
     )
+
+    GEMM_UKR_FLUSH_CT( s );
 }

View File

@@ -42,6 +42,8 @@
 \
 void PASTEMAC3(ch,opname,arch,suf) \
      ( \
+       dim_t               m, \
+       dim_t               n, \
        dim_t               k, \
        ctype*     restrict alpha, \
        ctype*     restrict a, \
@@ -59,9 +61,6 @@ void PASTEMAC3(ch,opname,arch,suf) \
 \
     const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
     const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
-\
-    const dim_t m = mr; \
-    const dim_t n = nr; \
 \
     const inc_t cs_a = packmr; \
 \

View File

@@ -87,6 +87,8 @@ PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11", mr, 2*nr, \
     /* upper: b11 = alpha * b11 - a12 * b21; */ \
     gemm_ukr \
     ( \
+      mr, \
+      nr, \
       k, \
       minus_one, \
       a1x, \

View File

@@ -44,6 +44,8 @@
 \
 void PASTEMAC3(ch,opname,arch,suf) \
      ( \
+       dim_t               m, \
+       dim_t               n, \
        dim_t               k, \
        ctype*     restrict alpha, \
        ctype*     restrict a, \
@@ -107,8 +109,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 \
     if ( PASTEMAC(ch,eq0)( *beta ) ) \
     { \
-        for ( dim_t i = 0; i < mr; ++i ) \
-        for ( dim_t j = 0; j < nr; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
         PASTEMAC(ch,copys) \
         ( \
           ab[ i*rs_ab + j*cs_ab ], \
@@ -117,8 +119,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
     } \
     else \
     { \
-        for ( dim_t i = 0; i < mr; ++i ) \
-        for ( dim_t j = 0; j < nr; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
         PASTEMAC(ch,xpbys) \
         ( \
           ab[ i*rs_ab + j*cs_ab ], \
@@ -133,8 +135,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 \
     if ( PASTEMAC(ch,eq0)( *beta ) ) \
     { \
-        for ( dim_t j = 0; j < nr; ++j ) \
-        for ( dim_t i = 0; i < mr; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
         PASTEMAC(ch,copys) \
         ( \
           ab[ i*rs_ab + j*cs_ab ], \
@@ -143,8 +145,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
     } \
     else \
     { \
-        for ( dim_t j = 0; j < nr; ++j ) \
-        for ( dim_t i = 0; i < mr; ++i ) \
+        for ( dim_t j = 0; j < n; ++j ) \
+        for ( dim_t i = 0; i < m; ++i ) \
         PASTEMAC(ch,xpbys) \
         ( \
           ab[ i*rs_ab + j*cs_ab ], \
@@ -171,6 +173,8 @@ GENTFUNC( dcomplex, z, gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 )
 \
 void PASTEMAC3(ch,opname,arch,suf) \
      ( \
+       dim_t               m, \
+       dim_t               n, \
        dim_t               k, \
        ctype*     restrict alpha, \
        ctype*     restrict a, \
@@ -188,9 +192,6 @@ void PASTEMAC3(ch,opname,arch,suf) \
 \
     const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
     const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
-\
-    const dim_t m = mr; \
-    const dim_t n = nr; \
 \
     const inc_t cs_a = packmr; \
 \

View File

@@ -52,6 +52,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 { \
     const num_t dt = PASTEMAC(ch,type); \
 \
+    const inc_t mr     = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
+    const inc_t nr     = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
     const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
 \
     const inc_t rs_b = packnr; \
@@ -68,6 +70,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
     /* upper: b11 = alpha * b11 - a12 * b21; */ \
     gemm_ukr \
     ( \
+      mr, \
+      nr, \
       k, \
       minus_one, \
       a1x, \

View File

@@ -39,6 +39,8 @@
 \
 void PASTEMAC3(ch,opname,arch,suf) \
      ( \
+       dim_t               m, \
+       dim_t               n, \
        dim_t               k, \
        ctype*     restrict alpha, \
        ctype*     restrict a, \
@@ -59,6 +61,9 @@ void PASTEMAC3(ch,opname,arch,suf) \
 \
     const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
     const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
+\
+    const dim_t mr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \
+    const dim_t nr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \
 \
     const dim_t k2 = 2 * k; \
 \
@@ -118,6 +123,11 @@ void PASTEMAC3(ch,opname,arch,suf) \
     else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \
     else                                        using_ct = FALSE; \
 \
+\
+    /* If we are not computing a full micro-tile, then we must write to \
+       ct and then accumulate to c afterwards. */ \
+    if ( mr != m || nr != n ) using_ct = TRUE; \
+\
 \
     if ( using_ct ) \
     { \
@@ -149,6 +159,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
         /* c = beta * c + alpha_r * a * b; */ \
         rgemm_ukr \
         ( \
+          mr_r, \
+          nr_r, \
           k2, \
           alpha_r, \
           a_r, \
@@ -164,8 +176,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
         /* Accumulate the final result in ct back to c. */ \
         if ( PASTEMAC(ch,eq1)( *beta ) ) \
         { \
-            for ( j = 0; j < nr; ++j ) \
-            for ( i = 0; i < mr; ++i ) \
+            for ( j = 0; j < n; ++j ) \
+            for ( i = 0; i < m; ++i ) \
             { \
                 PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
                                    *(c  + i*rs_c  + j*cs_c ) ); \
@@ -173,8 +185,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
         } \
         else if ( PASTEMAC(ch,eq0)( *beta ) ) \
         { \
-            for ( j = 0; j < nr; ++j ) \
-            for ( i = 0; i < mr; ++i ) \
+            for ( j = 0; j < n; ++j ) \
+            for ( i = 0; i < m; ++i ) \
             { \
                 PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
                                     *(c  + i*rs_c  + j*cs_c ) ); \
@@ -182,8 +194,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
         } \
         else \
         { \
-            for ( j = 0; j < nr; ++j ) \
-            for ( i = 0; i < mr; ++i ) \
+            for ( j = 0; j < n; ++j ) \
+            for ( i = 0; i < m; ++i ) \
             { \
                 PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
                                     *beta, \
@@ -215,6 +227,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
         /* c = beta * c + alpha_r * a * b; */ \
         rgemm_ukr \
         ( \
+          mr_r, \
+          nr_r, \
           k2, \
           alpha_r, \
           a_r, \

View File

@@ -153,6 +153,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
        upper: bt = -1.0 * a12 * b21; */ \
     rgemm_ukr \
     ( \
+      mr_r, \
+      nr_r, \
       k2, \
       minus_one_r, \
       a1x_r, \
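The call-site half of the change has the same shape everywhere: the macrokernel computes the true edge dimensions of the current tile and passes them straight through, instead of routing partial tiles to separate edge-case code. A hypothetical caller-side sketch, assuming the post-change dgemm_ukr_ft signature; the panel addressing and helper names here are illustrative only, not the actual BLIS macrokernel:

#include "blis.h"

/* Hypothetical macrokernel inner loops over an m x n block of C. */
static void macrokernel_sketch
     (
       dgemm_ukr_ft gemm_ukr,
       dim_t m, dim_t n, dim_t k, dim_t mr, dim_t nr,
       double* alpha, double* a, double* b, double* beta,
       double* c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* aux, cntx_t* cntx
     )
{
    for ( dim_t j = 0; j < n; j += nr )
    for ( dim_t i = 0; i < m; i += mr )
    {
        /* Edge tiles pass m_cur < mr and/or n_cur < nr through unchanged;
           no scratch buffer is needed at this level anymore. */
        dim_t m_cur = bli_min( m - i, mr );
        dim_t n_cur = bli_min( n - j, nr );

        gemm_ukr( m_cur, n_cur, k, alpha,
                  a + i*k,    /* packed micropanel of A (illustrative) */
                  b + j*k,    /* packed micropanel of B (illustrative) */
                  beta, c + i*rs_c + j*cs_c, rs_c, cs_c, aux, cntx );
    }
}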

View File

@@ -0,0 +1,267 @@
#include <cmath>
#include <algorithm>
#include <type_traits>
#include "blis.h"
template <typename T>
struct is_complex : std::false_type {};
template <>
struct is_complex<scomplex> : std::true_type {};
template <>
struct is_complex<dcomplex> : std::true_type {};
template <typename T>
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};
template <typename T> struct make_complex;
template <> struct make_complex<float > { using type = scomplex; };
template <> struct make_complex<double > { using type = dcomplex; };
template <> struct make_complex<scomplex> { using type = scomplex; };
template <> struct make_complex<dcomplex> { using type = dcomplex; };
template <typename T>
using make_complex_t = typename make_complex<T>::type;
template <typename T> struct make_real;
template <> struct make_real<float > { using type = float; };
template <> struct make_real<double > { using type = double; };
template <> struct make_real<scomplex> { using type = float; };
template <> struct make_real<dcomplex> { using type = double; };
template <typename T>
using make_real_t = typename make_real<T>::type;
template <typename T, bool Cond>
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};
template <typename T, bool Cond>
using make_complex_if_t = typename make_complex_if<T,Cond>::type;
template <typename T>
struct real_imag_part
{
real_imag_part& operator=(T) { return *this; }
operator T() const { return T(); }
};
template <typename T>
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }
template <typename T>
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }
inline float& real(scomplex& x) { return x.real; }
inline float& imag(scomplex& x) { return x.imag; }
inline double& real(dcomplex& x) { return x.real; }
inline double& imag(dcomplex& x) { return x.imag; }
inline const float& real(const scomplex& x) { return x.real; }
inline const float& imag(const scomplex& x) { return x.imag; }
inline const double& real(const dcomplex& x) { return x.real; }
inline const double& imag(const dcomplex& x) { return x.imag; }
template <typename T>
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
template <typename T>
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
template <typename T, typename U, typename=void>
struct convert_impl;
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
{
void operator()(T x, U& y) const { y = x; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
{
void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
{
void operator()(T x, U& y) const { y = x.real; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
{
void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
};
template <typename U, typename T>
U convert(T x)
{
U y;
convert_impl<T,U>{}(x,y);
return y;
}
template <typename U, typename T>
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
{
return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
}
#define COMPLEX_MATH_OPS(rtype, ctype) \
\
inline bool operator==(rtype x, ctype y) \
{ \
return x == y.real && y.imag == 0; \
} \
\
inline bool operator==(ctype x, rtype y) \
{ \
return y == x.real && x.imag == 0; \
} \
\
inline bool operator==(ctype x, ctype y) \
{ \
return x.real == y.real && \
x.imag == y.imag; \
} \
\
inline ctype operator-(ctype x) \
{ \
return {-x.real, -x.imag}; \
} \
\
inline ctype operator+(rtype x, ctype y) \
{ \
return {x+y.real, y.imag}; \
} \
\
inline ctype operator+(ctype x, rtype y) \
{ \
return {y+x.real, x.imag}; \
} \
\
inline ctype operator+(ctype x, ctype y) \
{ \
return {x.real+y.real, x.imag+y.imag}; \
} \
\
inline ctype operator-(rtype x, ctype y) \
{ \
return {x-y.real, -y.imag}; \
} \
\
inline ctype operator-(ctype x, rtype y) \
{ \
return {x.real-y, x.imag}; \
} \
\
inline ctype operator-(ctype x, ctype y) \
{ \
return {x.real-y.real, x.imag-y.imag}; \
} \
\
inline ctype operator*(rtype x, ctype y) \
{ \
return {x*y.real, x*y.imag}; \
} \
\
inline ctype operator*(ctype x, rtype y) \
{ \
return {y*x.real, y*x.imag}; \
} \
\
inline ctype operator*(ctype x, ctype y) \
{ \
return {x.real*y.real - x.imag*y.imag, \
x.real*y.imag + x.imag*y.real}; \
} \
\
inline ctype operator/(rtype x, ctype y) \
{ \
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
auto n = std::ilogb(scale); \
auto yrs = std::scalbn(y.real, -n); \
auto yis = std::scalbn(y.imag, -n); \
auto denom = y.real*yrs + y.imag*yis; \
return {x*yrs/denom, -x*yis/denom}; \
} \
\
inline ctype operator/(ctype x, rtype y) \
{ \
return {x.real/y, x.imag/y}; \
} \
\
inline ctype operator/(ctype x, ctype y) \
{ \
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
auto n = std::ilogb(scale); \
auto yrs = std::scalbn(y.real, -n); \
auto yis = std::scalbn(y.imag, -n); \
auto denom = y.real*yrs + y.imag*yis; \
return {(x.real*yrs + x.imag*yis)/denom, \
(x.imag*yrs - x.real*yis)/denom}; \
} \
\
inline ctype& operator+=(ctype& x, rtype y) \
{ \
x.real += y; \
return x; \
} \
\
inline ctype& operator+=(ctype& x, ctype y) \
{ \
x.real += y.real; x.imag += y.imag; \
return x; \
} \
\
inline ctype& operator-=(ctype& x, rtype y) \
{ \
x.real -= y; \
return x; \
} \
\
inline ctype& operator-=(ctype& x, ctype y) \
{ \
x.real -= y.real; x.imag -= y.imag; \
return x; \
} \
\
inline ctype& operator*=(ctype& x, rtype y) \
{ \
x.real *= y; x.imag *= y; \
return x; \
} \
\
inline ctype& operator*=(ctype& x, ctype y) \
{ \
x = x * y; \
return x; \
} \
\
inline ctype& operator/=(ctype& x, rtype y) \
{ \
x.real /= y; x.imag /= y; \
return x; \
} \
\
inline ctype& operator/=(ctype& x, ctype y) \
{ \
x = x / y; \
return x; \
}
COMPLEX_MATH_OPS(float, scomplex);
COMPLEX_MATH_OPS(double, dcomplex);
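The full complex divisions above are a scaled form of the textbook formula: y is scaled down by 2^n, with n = ilogb(max(|Re y|, |Im y|)), before the denominator is formed, which keeps |y|^2 from overflowing or underflowing in the intermediate. Writing y'_r = 2^{-n} Re(y) and y'_i = 2^{-n} Im(y), the operator/ bodies compute

    \frac{x}{y} \;=\; \frac{x\,\overline{y}}{|y|^2}
                \;=\; \frac{(x_r y'_r + x_i y'_i) \;+\; i\,(x_i y'_r - x_r y'_i)}
                           {y_r y'_r + y_i y'_i},
    \qquad y_r y'_r + y_i y'_i \;=\; 2^{-n}\,|y|^2 .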

View File

@@ -0,0 +1,186 @@
#include "syrk_diagonal_ref.h"
/*
* Structure which includes all additional information beyond what is
* already stored in the obj_t structure.
*
* This structure is **read-only** during the operation!
*/
typedef struct packm_diag_params_t
{
packm_blk_var1_params_t super;
void* d;
inc_t incd;
} packm_diag_params_t;
/*
* Declare the pack kernel type and set up an array of
* packing kernels, one for each data type.
*/
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
void PASTEMAC(ch,op) \
( \
struc_t struca, \
diag_t diaga, \
uplo_t uploa, \
conj_t conja, \
pack_t schema, \
bool invdiag, \
dim_t panel_dim, \
dim_t panel_len, \
dim_t panel_dim_max, \
dim_t panel_len_max, \
dim_t panel_dim_off, \
dim_t panel_len_off, \
void* restrict kappa, \
void* restrict a, inc_t inca, inc_t lda, \
void* restrict p, inc_t ldp, \
inc_t is_p, \
cntx_t* cntx, \
void* params \
) \
{ \
packm_diag_params_t* params_cast = params; \
ctype* restrict a_cast = a; \
ctype* restrict p_cast = p; \
ctype* restrict d_cast = params_cast->d; \
inc_t incd = params_cast->incd; \
ctype kappa_cast = *( ctype* )kappa; \
\
if ( schema != BLIS_PACKED_ROW_PANELS && \
schema != BLIS_PACKED_COL_PANELS ) \
bli_abort(); \
\
/* Apply the offset */ \
d_cast += panel_len_off * incd; \
\
if ( conja ) \
{ \
for ( dim_t j = 0; j < panel_len; j++ ) \
{ \
ctype kappa_d; \
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
\
for (dim_t i = 0;i < panel_dim;i++) \
PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
\
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
} \
} \
else \
{ \
for ( dim_t j = 0; j < panel_len; j++ ) \
{ \
ctype kappa_d; \
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
\
for (dim_t i = 0;i < panel_dim;i++) \
PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
\
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
} \
} \
\
for (dim_t j = panel_len;j < panel_len_max;j++) \
for (dim_t i = 0;i < panel_dim_max;i++) \
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
}
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
/*
* Modify the object A to include information about the diagonal D,
* and imbue it with special function pointers which will take care
* of the actual work of forming (D * A^T)
*/
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
{
memset( params, 0, sizeof(*params) );
// Assumes D is a column vector
params->d = bli_obj_buffer_at_off( d );
params->incd = bli_obj_row_stride( d );
for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
params->super.ukr_fn[i][i] = packm_diag_ukrs[i];
// Attach the parameters to the A object.
bli_obj_set_pack_params( params, a );
}
/*
* Implements C := alpha * A * D * A^T + beta * C
*
* where D is a diagonal matrix with elements taken from the "d" vector.
*/
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
obj_t ad; // this is (D * A^T)
packm_diag_params_t params;
bli_obj_alias_to( a, &ad );
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
attach_diagonal_factor( &params, d, &ad );
// Does C := alpha * A * B + beta * C using B = (D * A^T)
bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
}
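Note that B = D * A^T is never materialized: the pack kernel above folds the diagonal into the packing of A^T, producing each packed element on the fly as kappa * d_p * a_jp (the scal2s/scal2js lines), so the gemmt call evaluates

    C \;:=\; \alpha\, A \,(D A^{\mathsf T}) \;+\; \beta\, C,
    \qquad (D A^{\mathsf T})_{pj} \;=\; d_p\, a_{jp} .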
int main( void )
{
obj_t a;
obj_t d;
obj_t c;
obj_t c_copy;
obj_t norm;
dim_t m = 10;
dim_t k = 10;
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
for ( int upper = 0; upper <= 1; upper++ )
for ( int transa = 0; transa <= 1; transa++ )
for ( int transc = 0; transc <= 1; transc++ )
{
num_t dt = dt_;
uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
bli_obj_create( dt, k, 1, 1, 1, &d );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
bli_obj_set_uplo( uplo , &c );
bli_obj_set_uplo( uplo , &c_copy );
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
bli_randm( &a );
bli_randm( &d );
bli_randm( &c );
bli_copym( &c, &c_copy );
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
bli_subm( &c_copy, &c );
bli_normfm( &c, &norm );
double normr, normi;
bli_getsc( &norm, &normr, &normi );
printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
dt, upper, transa, transc, normr );
bli_obj_free( &a );
bli_obj_free( &d );
bli_obj_free( &c );
bli_obj_free( &c_copy );
bli_obj_free( &norm );
}
}

View File

@@ -0,0 +1,220 @@
#include "syrk_diagonal_ref.h"
/*
* Forward-declare the pack kernel type and set up an array of
* packing kernels, one for each data type.
*/
template <typename T>
void packm_diag_ukr
(
struc_t /*struca*/,
diag_t /*diaga*/,
uplo_t /*uploa*/,
conj_t conja,
pack_t schema,
bool /*invdiag*/,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
dim_t /*panel_dim_off*/,
dim_t panel_len_off,
void* restrict kappa,
void* restrict a, inc_t inca, inc_t lda,
void* restrict p, inc_t ldp,
inc_t /*is_p*/,
cntx_t* /*cntx*/,
void* params
);
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
/*
* Structure which includes all additional information beyond what is
* already stored in the obj_t structure.
*
* This structure is **read-only** during the operation!
*/
struct packm_diag_params_t : packm_blk_var1_params_t
{
void* d;
inc_t incd;
packm_diag_params_t() {}
packm_diag_params_t( void* d, inc_t incd )
: d(d), incd(incd)
{
for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
ukr_fn[i][i] = packm_diag_ukrs[i];
}
};
/*
* Selecting a different kernel based on the current architecture is
* currently not possible, but is something we plan to support.
*/
template <typename T>
void packm_diag_ukr
(
struc_t /*struca*/,
diag_t /*diaga*/,
uplo_t /*uploa*/,
conj_t conja,
pack_t schema,
bool /*invdiag*/,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
dim_t /*panel_dim_off*/,
dim_t panel_len_off,
void* restrict kappa,
void* restrict a, inc_t inca, inc_t lda,
void* restrict p, inc_t ldp,
inc_t /*is_p*/,
cntx_t* /*cntx*/,
void* params
)
{
auto params_cast = ( packm_diag_params_t* )params;
T* restrict a_cast = ( T* )a;
T* restrict p_cast = ( T* )p;
T* restrict d_cast = ( T* )params_cast->d;
auto incd = params_cast->incd;
auto kappa_cast = *( T* )kappa;
if ( schema != BLIS_PACKED_ROW_PANELS &&
schema != BLIS_PACKED_COL_PANELS )
bli_abort();
/* Apply the offset */
d_cast += panel_len_off * incd;
if ( conja )
{
for ( dim_t j = 0; j < panel_len; j++ )
{
auto kappa_d = kappa_cast * d_cast[ j*incd ];
for (dim_t i = 0;i < panel_dim;i++)
p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
for (dim_t i = panel_dim;i < panel_dim_max;i++)
p_cast[ i + j*ldp ] = convert<T>(0.0);
}
}
else
{
for ( dim_t j = 0; j < panel_len; j++ )
{
auto kappa_d = kappa_cast * d_cast[ j*incd ];
for (dim_t i = 0;i < panel_dim;i++)
p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
for (dim_t i = panel_dim;i < panel_dim_max;i++)
p_cast[ i + j*ldp ] = convert<T>(0.0);
}
}
for (dim_t j = panel_len;j < panel_len_max;j++)
for (dim_t i = 0;i < panel_dim_max;i++)
p_cast[ i + j*ldp ] = convert<T>(0.0);
}
/*
* Modify the object A to include information about the diagonal D,
* and imbue it with special function pointers which will take care
* of the actual work of forming (D * A^T)
*/
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
{
// Assumes D is a column vector
new (params) packm_diag_params_t
(
bli_obj_buffer_at_off( d ),
bli_obj_row_stride( d )
);
// Attach the parameters to the A object.
bli_obj_set_pack_params( params, a );
}
/*
* Implements C := alpha * A * D * A^T + beta * C
*
* where D is a diagonal matrix with elements taken from the "d" vector.
*/
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
obj_t ad; // this is (D * A^T)
packm_diag_params_t params;
bli_obj_alias_to( a, &ad );
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
attach_diagonal_factor( &params, d, &ad );
// Does C := alpha * A * B + beta * C using B = (D * A^T)
bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
}
int main()
{
obj_t a;
obj_t d;
obj_t c;
obj_t c_copy;
obj_t norm;
auto m = 10;
auto k = 10;
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
for ( int upper = 0; upper <= 1; upper++ )
for ( int transa = 0; transa <= 1; transa++ )
for ( int transc = 0; transc <= 1; transc++ )
{
auto dt = ( num_t )dt_;
auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
bli_obj_create( dt, k, 1, 1, 1, &d );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
bli_obj_set_uplo( uplo , &c );
bli_obj_set_uplo( uplo , &c_copy );
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
bli_randm( &a );
bli_randm( &d );
bli_randm( &c );
bli_copym( &c, &c_copy );
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
bli_subm( &c_copy, &c );
bli_normfm( &c, &norm );
double normr, normi;
bli_getsc( &norm, &normr, &normi );
printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
dt, upper, transa, transc, normr);
bli_obj_free( &a );
bli_obj_free( &d );
bli_obj_free( &c );
bli_obj_free( &c_copy );
bli_obj_free( &norm );
}
}

View File

@@ -0,0 +1,354 @@
#include "syrk_diagonal_ref.h"
/*
* Structure which includes all additional information beyond what is
* already stored in the obj_t structure.
*
* This structure is **read-only** during the operation!
*/
typedef struct packm_diag_params_t
{
void* d;
inc_t incd;
} packm_diag_params_t;
typedef void (*packm_diag_ukr_vft)
(
bool conja,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
void* restrict kappa,
void* restrict d, inc_t incd,
void* restrict a, inc_t inca, inc_t lda,
void* restrict p, inc_t ldp
);
/*
* Declare the pack kernel type and set up an array of
* packing kernels, one for each data type.
*/
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
void PASTEMAC(ch,op) \
( \
bool conja, \
dim_t panel_dim, \
dim_t panel_len, \
dim_t panel_dim_max, \
dim_t panel_len_max, \
void* restrict kappa, \
void* restrict d, inc_t incd, \
void* restrict a, inc_t inca, inc_t lda, \
void* restrict p, inc_t ldp \
) \
{ \
ctype* restrict a_cast = a; \
ctype* restrict p_cast = p; \
ctype* restrict d_cast = d; \
ctype kappa_cast = *( ctype* )kappa; \
\
if ( conja ) \
{ \
for ( dim_t j = 0; j < panel_len; j++ ) \
{ \
ctype kappa_d; \
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
\
for (dim_t i = 0;i < panel_dim;i++) \
PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
\
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
} \
} \
else \
{ \
for ( dim_t j = 0; j < panel_len; j++ ) \
{ \
ctype kappa_d; \
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
\
for (dim_t i = 0;i < panel_dim;i++) \
PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
\
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
} \
} \
\
for (dim_t j = panel_len;j < panel_len_max;j++) \
for (dim_t i = 0;i < panel_dim_max;i++) \
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
}
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
void packm_diag
(
obj_t* a,
obj_t* p,
cntx_t* cntx,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
#if 1
// We begin by copying the fields of A.
bli_obj_alias_to( a, p );
// Get information about data types.
num_t dt = bli_obj_dt( a );
num_t dt_tar = bli_obj_target_dt( a );
num_t dt_scalar = bli_obj_scalar_dt( a );
dim_t dt_size = bli_dt_size( dt );
if ( dt_scalar != dt || dt_tar != dt )
bli_abort();
// Extract various fields from the control tree.
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
pack_t schema = bli_cntl_packm_params_pack_schema( cntl );
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
if ( schema != BLIS_PACKED_ROW_PANELS &&
schema != BLIS_PACKED_COL_PANELS )
bli_abort();
// Store the pack schema to the object.
bli_obj_set_pack_schema( schema, p );
// Clear the conjugation field from the object since matrix packing
// in BLIS is deemed to take care of all conjugation necessary.
bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
// If we are packing micropanels, mark P as dense.
bli_obj_set_uplo( BLIS_DENSE, p );
// Reset the view offsets to (0,0).
bli_obj_set_offs( 0, 0, p );
// Compute the dimensions padded by the dimension multiples. These
// dimensions will be the dimensions of the packed matrices, including
// zero-padding, and will be used by the macro- and micro-kernels.
// We compute them by starting with the effective dimensions of A (now
// in P) and aligning them to the dimension multiples (typically equal
// to register blocksizes). This does waste a little bit of space for
// level-2 operations, but that's okay with us.
dim_t m_p = bli_obj_length( p );
dim_t n_p = bli_obj_width( p );
dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
// Save the padded dimensions into the packed object. It is important
// to save these dimensions since they represent the actual dimensions
// of the zero-padded matrix.
bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
// The "panel stride" of a micropanel packed object is interpreted as
// the distance between the (0,0) element of panel k and the (0,0)
// element of panel k+1. We use the padded width computed above to
// allow for zero-padding (if necessary/desired) along the far end
// of each micropanel (ie: the right edge of the matrix). Zero-padding
// can also occur along the long edge of the last micropanel if the m
// dimension of the matrix is not a whole multiple of MR.
inc_t ps_p = bmult_m_pack * n_p_pad;
/* Compute the total number of iterations we'll need. */
dim_t n_iter = m_p_pad / bmult_m_def;
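// Worked example with the sizes used in main() below, assuming the
// register blocking for the target datatype is 4 (illustrative numbers
// only): m_p = n_p = 10 with bmult_m_def = bmult_n_def = bmult_m_pack = 4
// gives m_p_pad = n_p_pad = 12, hence ps_p = 4 * 12 = 48 elements per
// micropanel and n_iter = 12 / 4 = 3 micropanels to pack.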
// Store the strides and panel dimension in P.
bli_obj_set_strides( 1, bmult_m_pack, p );
bli_obj_set_imag_stride( 1, p );
bli_obj_set_panel_dim( bmult_m_def, p );
bli_obj_set_panel_stride( ps_p, p );
bli_obj_set_panel_length( bmult_m_def, p );
bli_obj_set_panel_width( n_p, p );
// Compute the size of the packed buffer.
siz_t size_p = ps_p * n_iter * dt_size;
if ( size_p == 0 ) return;
// Update the buffer address in p to point to the buffer associated
// with the mem_t entry acquired from the memory broker (now cached in
// the control tree node).
char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
bli_obj_set_buffer( p_cast, p );
#else
// Every thread initializes p and determines the size of memory
// block needed (which gets embedded into the otherwise "blank" mem_t
// entry in the control tree node). Return early if no packing is required.
if ( !bli_packm_init( a, p, cntx, rntm, cntl, thread ) )
return;
num_t dt = bli_obj_dt( a );
dim_t dt_size = bli_dt_size( dt );
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt, bmult_id_m, cntx );
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt, bmult_id_m, cntx );
dim_t m_p = bli_obj_length( p );
dim_t n_p = bli_obj_width( p );
dim_t m_p_pad = bli_obj_padded_length( p );
dim_t n_p_pad = bli_obj_padded_width( p );
dim_t n_iter = m_p_pad / bmult_m_def;
char* p_cast = bli_obj_buffer( p );
inc_t ps_p = bli_obj_panel_stride( p );
#endif
char* a_cast = bli_obj_buffer_at_off( a );
inc_t inca = bli_obj_row_stride( a );
inc_t lda = bli_obj_col_stride( a );
dim_t panel_len_off = bli_obj_col_off( a );
conj_t conja = bli_obj_conj_status( a );
packm_diag_params_t* params = bli_obj_pack_params( a );
char* d_cast = params->d;
inc_t incd = params->incd;
obj_t kappa_local;
char* kappa_cast = bli_packm_scalar( &kappa_local, p );
packm_diag_ukr_vft packm_ker_cast = packm_diag_ukrs[ dt ];
/* Query the number of threads and thread ids from the current thread's
packm thrinfo_t node. */
const dim_t nt = bli_thread_n_way( thread );
const dim_t tid = bli_thread_work_id( thread );
/* Determine the thread range and increment using the current thread's
packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
will depend on whether slab or round-robin partitioning was requested
at configure-time. */
dim_t it_start, it_end, it_inc;
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
/* Iterate over every logical micropanel in the source matrix. */
for ( dim_t it = 0; it < n_iter; it += 1 )
{
dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
char* d_begin = d_cast + panel_len_off*incd*dt_size;
char* a_begin = a_cast + it* bmult_m_def*inca*dt_size;
char* p_begin = p_cast + it* ps_p*dt_size;
if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
{
packm_ker_cast
(
conja,
panel_dim_i,
n_p,
bmult_m_def,
n_p_pad,
kappa_cast,
d_begin, incd,
a_begin, inca, lda,
p_begin, bmult_m_pack
);
}
}
}
/*
* Modify the object A to include information about the diagonal D,
* and imbue it with special function pointers which will take care
* of the actual work of forming (D * A^T)
*/
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
{
// Assumes D is a column vector
params->d = bli_obj_buffer_at_off( d );
params->incd = bli_obj_row_stride( d );
// Set the custom pack function.
bli_obj_set_pack_fn( packm_diag, a );
// Attach the parameters to the A object.
bli_obj_set_pack_params( params, a );
}
/*
* Implements C := alpha * A * D * A^T + beta * C
*
* where D is a diagonal matrix with elements taken from the "d" vector.
*/
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
obj_t ad; // this is (D * A^T)
packm_diag_params_t params;
bli_obj_alias_to( a, &ad );
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
attach_diagonal_factor( &params, d, &ad );
// Does C := alpha * A * B + beta * C using B = (D * A^T)
bli_gemmt( alpha, a, &ad, beta, c );
}
int main( void )
{
obj_t a;
obj_t d;
obj_t c;
obj_t c_copy;
obj_t norm;
dim_t m = 10;
dim_t k = 10;
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
for ( int upper = 0; upper <= 1; upper++ )
for ( int transa = 0; transa <= 1; transa++ )
for ( int transc = 0; transc <= 1; transc++ )
{
num_t dt = dt_;
uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
bli_obj_create( dt, k, 1, 1, 1, &d );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
bli_obj_set_uplo( uplo , &c );
bli_obj_set_uplo( uplo , &c_copy );
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
bli_randm( &a );
bli_randm( &d );
bli_randm( &c );
bli_copym( &c, &c_copy );
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
bli_subm( &c_copy, &c );
bli_normfm( &c, &norm );
double normr, normi;
bli_getsc( &norm, &normr, &normi );
printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
dt, upper, transa, transc, normr );
bli_obj_free( &a );
bli_obj_free( &d );
bli_obj_free( &c );
bli_obj_free( &c_copy );
bli_obj_free( &norm );
}
}

View File

@@ -0,0 +1,338 @@
#include "syrk_diagonal_ref.h"
/*
* Forward-declare the pack kernel type and set up an array of
* packing kernels, one for each data type.
*/
template <typename T>
void packm_diag_ukr
(
bool conja,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
void* restrict kappa,
void* restrict d, inc_t incd,
void* restrict a, inc_t inca, inc_t lda,
void* restrict p, inc_t ldp
);
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
using packm_diag_ukr_vft = decltype(&packm_diag_ukr<void>);
static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
/*
* Structure which includes all additional information beyond what is
* already stored in the obj_t structure.
*
* This structure is **read-only** during the operation!
*/
struct packm_diag_params_t
{
void* d;
inc_t incd;
packm_diag_params_t() {}
packm_diag_params_t( void* d, inc_t incd )
: d(d), incd(incd) {}
};
/*
* Selecting a different kernel based on the current architecture is
* currently not possible, but is something we plan to support.
*/
template <typename T>
void packm_diag_ukr
(
bool conja,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
void* restrict kappa,
void* restrict d, inc_t incd,
void* restrict a, inc_t inca, inc_t lda,
void* restrict p, inc_t ldp
)
{
T* restrict a_cast = ( T* )a;
T* restrict p_cast = ( T* )p;
T* restrict d_cast = ( T* )d;
auto kappa_cast = *( T* )kappa;
if ( conja )
{
for ( dim_t j = 0; j < panel_len; j++ )
{
auto kappa_d = kappa_cast * d_cast[ j*incd ];
for (dim_t i = 0;i < panel_dim;i++)
p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
for (dim_t i = panel_dim;i < panel_dim_max;i++)
p_cast[ i + j*ldp ] = convert<T>(0.0);
}
}
else
{
for ( dim_t j = 0; j < panel_len; j++ )
{
auto kappa_d = kappa_cast * d_cast[ j*incd ];
for (dim_t i = 0;i < panel_dim;i++)
p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
for (dim_t i = panel_dim;i < panel_dim_max;i++)
p_cast[ i + j*ldp ] = convert<T>(0.0);
}
}
for (dim_t j = panel_len;j < panel_len_max;j++)
for (dim_t i = 0;i < panel_dim_max;i++)
p_cast[ i + j*ldp ] = convert<T>(0.0);
}
void packm_diag
(
obj_t* a,
obj_t* p,
cntx_t* cntx,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
// We begin by copying the fields of A.
bli_obj_alias_to( a, p );
// Get information about data types.
num_t dt = bli_obj_dt( a );
num_t dt_tar = bli_obj_target_dt( a );
num_t dt_scalar = bli_obj_scalar_dt( a );
dim_t dt_size = bli_dt_size( dt );
if ( dt_scalar != dt || dt_tar != dt )
bli_abort();
// Extract various fields from the control tree.
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
pack_t schema = bli_cntl_packm_params_pack_schema( cntl );
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
if ( schema != BLIS_PACKED_ROW_PANELS &&
schema != BLIS_PACKED_COL_PANELS )
bli_abort();
// Store the pack schema to the object.
bli_obj_set_pack_schema( schema, p );
// Clear the conjugation field from the object since matrix packing
// in BLIS is deemed to take care of all conjugation necessary.
bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
// If we are packing micropanels, mark P as dense.
bli_obj_set_uplo( BLIS_DENSE, p );
// Reset the view offsets to (0,0).
bli_obj_set_offs( 0, 0, p );
// Compute the dimensions padded by the dimension multiples. These
// dimensions will be the dimensions of the packed matrices, including
// zero-padding, and will be used by the macro- and micro-kernels.
// We compute them by starting with the effective dimensions of A (now
// in P) and aligning them to the dimension multiples (typically equal
// to register blocksizes). This does waste a little bit of space for
// level-2 operations, but that's okay with us.
dim_t m_p = bli_obj_length( p );
dim_t n_p = bli_obj_width( p );
dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
// Save the padded dimensions into the packed object. It is important
// to save these dimensions since they represent the actual dimensions
// of the zero-padded matrix.
bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
// The "panel stride" of a micropanel packed object is interpreted as
// the distance between the (0,0) element of panel k and the (0,0)
// element of panel k+1. We use the padded width computed above to
// allow for zero-padding (if necessary/desired) along the far end
// of each micropanel (ie: the right edge of the matrix). Zero-padding
// can also occur along the long edge of the last micropanel if the m
// dimension of the matrix is not a whole multiple of MR.
inc_t ps_p = bmult_m_pack * n_p_pad;
/* Compute the total number of iterations we'll need. */
dim_t n_iter = m_p_pad / bmult_m_def;
// Store the strides and panel dimension in P.
bli_obj_set_strides( 1, bmult_m_pack, p );
bli_obj_set_imag_stride( 1, p );
bli_obj_set_panel_dim( bmult_m_def, p );
bli_obj_set_panel_stride( ps_p, p );
bli_obj_set_panel_length( bmult_m_def, p );
bli_obj_set_panel_width( n_p, p );
// Compute the size of the packed buffer.
siz_t size_p = ps_p * n_iter * dt_size;
if ( size_p == 0 ) return;
// Update the buffer address in p to point to the buffer associated
// with the mem_t entry acquired from the memory broker (now cached in
// the control tree node).
char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
bli_obj_set_buffer( p_cast, p );
char* a_cast = (char*)bli_obj_buffer_at_off( a );
inc_t inca = bli_obj_row_stride( a );
inc_t lda = bli_obj_col_stride( a );
dim_t panel_len_off = bli_obj_col_off( a );
conj_t conja = bli_obj_conj_status( a );
auto params = (packm_diag_params_t*)bli_obj_pack_params( a );
char* d_cast = (char*)params->d;
inc_t incd = params->incd;
obj_t kappa_local;
char* kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );
auto packm_ker_cast = packm_diag_ukrs[ dt ];
/* Query the number of threads and thread ids from the current thread's
packm thrinfo_t node. */
const dim_t nt = bli_thread_n_way( thread );
const dim_t tid = bli_thread_work_id( thread );
/* Determine the thread range and increment using the current thread's
packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
will depend on whether slab or round-robin partitioning was requested
at configure-time. */
dim_t it_start, it_end, it_inc;
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
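/* Roughly: slab partitioning gives each thread one contiguous block of
iterations [it_start,it_end), while round-robin assigns iteration it to
the thread whose tid equals it % nt. */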
/* Iterate over every logical micropanel in the source matrix. */
for ( dim_t it = 0; it < n_iter; it += 1 )
{
dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
char* d_begin = d_cast + panel_len_off*incd*dt_size;
char* a_begin = a_cast + it* bmult_m_def*inca*dt_size;
char* p_begin = p_cast + it* ps_p*dt_size;
if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
{
packm_ker_cast( conja,
panel_dim_i,
n_p,
bmult_m_def,
n_p_pad,
kappa_cast,
d_begin, incd,
a_begin, inca, lda,
p_begin, bmult_m_pack );
}
}
}
/*
* Modify the object A to include information about the diagonal D,
* and imbue it with special function pointers which will take care
* of the actual work of forming (D * A^T)
*/
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
{
// Assumes D is a column vector
new (params) packm_diag_params_t
(
bli_obj_buffer_at_off( d ),
bli_obj_row_stride( d )
);
// Set the custom pack function.
bli_obj_set_pack_fn( packm_diag, a );
// Attach the parameters to the A object.
bli_obj_set_pack_params( params, a );
}
/*
* Implements C := alpha * A * D * A^T + beta * C
*
* where D is a diagonal matrix with elements taken from the "d" vector.
*/
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
obj_t ad; // this is (D * A^T)
packm_diag_params_t params;
bli_obj_alias_to( a, &ad );
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
attach_diagonal_factor( &params, d, &ad );
// Does C := alpha * A * B + beta * C using B = (D * A^T)
bli_gemmt( alpha, a, &ad, beta, c );
}
int main()
{
obj_t a;
obj_t d;
obj_t c;
obj_t c_copy;
obj_t norm;
auto m = 10;
auto k = 10;
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
for ( int upper = 0; upper <= 1; upper++ )
for ( int transa = 0; transa <= 1; transa++ )
for ( int transc = 0; transc <= 1; transc++ )
{
auto dt = ( num_t )dt_;
auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
bli_obj_create( dt, k, 1, 1, 1, &d );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
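// Note: the (row stride, col stride) pairs above select the storage order;
// e.g. for A, ( k, 1 ) is row-major and ( 1, m ) is column-major, so both
// storage layouts of each operand are exercised.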
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
bli_obj_set_uplo( uplo , &c );
bli_obj_set_uplo( uplo , &c_copy );
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
bli_randm( &a );
bli_randm( &d );
bli_randm( &c );
bli_copym( &c, &c_copy );
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
bli_subm( &c_copy, &c );
bli_normfm( &c, &norm );
double normr, normi;
bli_getsc( &norm, &normr, &normi );
printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
dt, upper, transa, transc, normr);
bli_obj_free( &a );
bli_obj_free( &d );
bli_obj_free( &c );
bli_obj_free( &c_copy );
bli_obj_free( &norm );
}
}

View File

@@ -0,0 +1,102 @@
#include "syrk_diagonal_ref.h"
#include "complex_math.hpp"
typedef void (*syrk_diag_ref_vft)
(
uplo_t uplo,
dim_t m,
dim_t k,
void* alpha,
void* a, inc_t rs_a, inc_t cs_a,
void* d, inc_t incd,
void* beta,
void* c, inc_t rs_c, inc_t cs_c
);
template <typename T>
void syrk_diag_ref
(
uplo_t uplo,
dim_t m,
dim_t k,
void* alpha,
void* a, inc_t rs_a, inc_t cs_a,
void* d, inc_t incd,
void* beta,
void* c, inc_t rs_c, inc_t cs_c
)
{
auto alpha_cast = *( T* )alpha;
auto beta_cast = *( T* )beta;
auto a_cast = ( T* )a;
auto d_cast = ( T* )d;
auto c_cast = ( T* )c;
for ( dim_t i = 0; i < m; i++ )
{
dim_t j_min = uplo == BLIS_UPPER ? i : 0;
dim_t j_max = uplo == BLIS_UPPER ? m : i+1;
for ( dim_t j = j_min; j < j_max; j++ )
{
auto ada = convert<T>(0.0);
for ( dim_t p = 0; p < k; p++ )
{
ada += a_cast[ i*rs_a + p*cs_a ] *
d_cast[ p*incd ] *
a_cast[ j*rs_a + p*cs_a ];
}
if ( beta_cast == convert<T>(0.0) )
{
c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada;
}
else
{
c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada +
beta_cast * c_cast[ i*rs_c + j*cs_c ];
}
}
}
}
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &syrk_diag_ref<ctype>;
INSERT_GENTFUNC_BASIC0(syrk_diag_ref);
static syrk_diag_ref_vft GENARRAY( syrk_diag_ref_impl, syrk_diag_ref );
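// Roughly, the GENTFUNC/GENARRAY pair above expands to one instantiation of
// the template per datatype plus a lookup table indexed by num_t, e.g.:
//
//   static syrk_diag_ref_vft syrk_diag_ref_impl[] =
//   {
//     bli_ssyrk_diag_ref, bli_csyrk_diag_ref,
//     bli_dsyrk_diag_ref, bli_zsyrk_diag_ref
//   };
//
// so the object-based wrapper below can dispatch on the runtime datatype.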
void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
num_t dt = bli_obj_dt( a );
dim_t m = bli_obj_length_after_trans( a );
dim_t k = bli_obj_width_after_trans( a );
inc_t rs_a = bli_obj_row_stride( a );
inc_t cs_a = bli_obj_col_stride( a );
inc_t rs_c = bli_obj_row_stride( c );
inc_t cs_c = bli_obj_col_stride( c );
inc_t incd = bli_obj_row_stride( d );
if ( bli_obj_has_trans( a ) )
bli_swap_incs( &rs_a, &cs_a );
if ( bli_obj_has_trans( c ) )
bli_swap_incs( &rs_c, &cs_c );
syrk_diag_ref_impl[ dt ]
(
bli_obj_uplo( c ),
m, k,
bli_obj_buffer_for_1x1( dt, alpha ),
bli_obj_buffer_at_off( a ), rs_a, cs_a,
bli_obj_buffer_at_off( d ), incd,
bli_obj_buffer_for_1x1( dt, beta ),
bli_obj_buffer_at_off( c ), rs_c, cs_c
);
}

View File

@@ -0,0 +1,8 @@
#include "blis.h"
#ifdef __cplusplus
#include "complex_math.hpp"
extern "C"
#endif
void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c );

View File

@@ -0,0 +1,267 @@
#include <cmath>
#include <algorithm>
#include <type_traits>
#include "blis.h"
template <typename T>
struct is_complex : std::false_type {};
template <>
struct is_complex<scomplex> : std::true_type {};
template <>
struct is_complex<dcomplex> : std::true_type {};
template <typename T>
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};
template <typename T> struct make_complex;
template <> struct make_complex<float > { using type = scomplex; };
template <> struct make_complex<double > { using type = dcomplex; };
template <> struct make_complex<scomplex> { using type = scomplex; };
template <> struct make_complex<dcomplex> { using type = dcomplex; };
template <typename T>
using make_complex_t = typename make_complex<T>::type;
template <typename T> struct make_real;
template <> struct make_real<float > { using type = float; };
template <> struct make_real<double > { using type = double; };
template <> struct make_real<scomplex> { using type = float; };
template <> struct make_real<dcomplex> { using type = double; };
template <typename T>
using make_real_t = typename make_real<T>::type;
template <typename T, bool Cond>
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};
template <typename T, bool Cond>
using make_complex_if_t = typename make_complex_if<T,Cond>::type;
template <typename T>
struct real_imag_part
{
real_imag_part& operator=(T) { return *this; }
operator T() const { return T(); }
};
template <typename T>
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }
template <typename T>
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }
inline float& real(scomplex& x) { return x.real; }
inline float& imag(scomplex& x) { return x.imag; }
inline double& real(dcomplex& x) { return x.real; }
inline double& imag(dcomplex& x) { return x.imag; }
inline const float& real(const scomplex& x) { return x.real; }
inline const float& imag(const scomplex& x) { return x.imag; }
inline const double& real(const dcomplex& x) { return x.real; }
inline const double& imag(const dcomplex& x) { return x.imag; }
template <typename T>
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
template <typename T>
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
template <typename T, typename U, typename=void>
struct convert_impl;
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
{
void operator()(T x, U& y) const { y = x; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
{
void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
{
void operator()(T x, U& y) const { y = x.real; }
};
template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
{
void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
};
template <typename U, typename T>
U convert(T x)
{
U y;
convert_impl<T,U>{}(x,y);
return y;
}
template <typename U, typename T>
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
{
return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
}
#define COMPLEX_MATH_OPS(rtype, ctype) \
\
inline bool operator==(rtype x, ctype y) \
{ \
return x == y.real && y.imag == 0; \
} \
\
inline bool operator==(ctype x, rtype y) \
{ \
return y == x.real && x.imag == 0; \
} \
\
inline bool operator==(ctype x, ctype y) \
{ \
return x.real == y.real && \
x.imag == y.imag; \
} \
\
inline ctype operator-(ctype x) \
{ \
return {-x.real, -x.imag}; \
} \
\
inline ctype operator+(rtype x, ctype y) \
{ \
return {x+y.real, y.imag}; \
} \
\
inline ctype operator+(ctype x, rtype y) \
{ \
return {y+x.real, x.imag}; \
} \
\
inline ctype operator+(ctype x, ctype y) \
{ \
return {x.real+y.real, x.imag+y.imag}; \
} \
\
inline ctype operator-(rtype x, ctype y) \
{ \
return {x-y.real, -y.imag}; \
} \
\
inline ctype operator-(ctype x, rtype y) \
{ \
return {x.real-y, x.imag}; \
} \
\
inline ctype operator-(ctype x, ctype y) \
{ \
return {x.real-y.real, x.imag-y.imag}; \
} \
\
inline ctype operator*(rtype x, ctype y) \
{ \
return {x*y.real, x*y.imag}; \
} \
\
inline ctype operator*(ctype x, rtype y) \
{ \
return {y*x.real, y*x.imag}; \
} \
\
inline ctype operator*(ctype x, ctype y) \
{ \
return {x.real*y.real - x.imag*y.imag, \
x.real*y.imag + x.imag*y.real}; \
} \
\
inline ctype operator/(rtype x, ctype y) \
{ \
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
auto n = std::ilogb(scale); \
auto yrs = std::scalbn(y.real, -n); \
auto yis = std::scalbn(y.imag, -n); \
auto denom = y.real*yrs + y.imag*yis; \
return {x*yrs/denom, -x*yis/denom}; \
} \
\
inline ctype operator/(ctype x, rtype y) \
{ \
return {x.real/y, x.imag/y}; \
} \
\
inline ctype operator/(ctype x, ctype y) \
{ \
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
auto n = std::ilogb(scale); \
auto yrs = std::scalbn(y.real, -n); \
auto yis = std::scalbn(y.imag, -n); \
auto denom = y.real*yrs + y.imag*yis; \
return {(x.real*yrs + x.imag*yis)/denom, \
(x.imag*yrs - x.real*yis)/denom}; \
} \
\
inline ctype& operator+=(ctype& x, rtype y) \
{ \
x.real += y; \
return x; \
} \
\
inline ctype& operator+=(ctype& x, ctype y) \
{ \
x.real += y.real; x.imag += y.imag; \
return x; \
} \
\
inline ctype& operator-=(ctype& x, rtype y) \
{ \
x.real -= y; \
return x; \
} \
\
inline ctype& operator-=(ctype& x, ctype y) \
{ \
x.real -= y.real; x.imag -= y.imag; \
return x; \
} \
\
inline ctype& operator*=(ctype& x, rtype y) \
{ \
x.real *= y; x.imag *= y; \
return x; \
} \
\
inline ctype& operator*=(ctype& x, ctype y) \
{ \
x = x * y; \
return x; \
} \
\
inline ctype& operator/=(ctype& x, rtype y) \
{ \
x.real /= y; x.imag /= y; \
return x; \
} \
\
inline ctype& operator/=(ctype& x, ctype y) \
{ \
x = x / y; \
return x; \
}
COMPLEX_MATH_OPS(float, scomplex);
COMPLEX_MATH_OPS(double, dcomplex);
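// Note: the complex division operators above first scale the denominator by a
// power of two (via ilogb/scalbn) before forming y.real*yrs + y.imag*yis,
// which avoids spurious overflow/underflow when |y| is very large or small.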

View File

@@ -0,0 +1,988 @@
#include "tcontract_ref.hpp"
#include <algorithm>
#include <numeric>
static constexpr dim_t BS_K = 8; // block size used for the panel-length (k) block-scatter vectors
struct packm_tensor_params_t
{
gint_t ndim_m, ndim_n;
const dim_t *len_m, *len_n;
const inc_t *stride_m, *stride_n;
packm_tensor_params_t() {}
packm_tensor_params_t( gint_t ndim_m, const dim_t* len_m, const inc_t* stride_m,
gint_t ndim_n, const dim_t* len_n, const inc_t* stride_n )
: ndim_m(ndim_m), ndim_n(ndim_n),
len_m(len_m), len_n(len_n),
stride_m(stride_m), stride_n(stride_n) {}
};
using gemm_tensor_params_t = packm_tensor_params_t;
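// The same (ndim, len, stride) description of the row and column dimension
// groups is reused by both the custom pack function (packm_tensor) and the
// custom gemm macrokernel (gemm_tensor), hence the alias above.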
template <typename T>
void packm_ckx_nb
(
bool conja,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
void* kappa,
void* a, inc_t inca, inc_t* bsa, inc_t* scata,
void* p, inc_t ldp
)
{
T* restrict a_cast = ( T* )a;
T* restrict p_cast = ( T* )p;
auto kappa_cast = *( T* )kappa;
if ( conja )
{
for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
{
auto lda = *bsa;
auto panel_len_j = std::min<dim_t>( panel_len-j0, BS_K );
if ( lda )
{
T* restrict aj = a_cast + *scata;
for ( auto j = 0; j < panel_len_j; j++ )
{
for ( auto i = 0; i < panel_dim; i++ )
p_cast[ i ] = kappa_cast * conj( aj[ i*inca + j*lda ] );
for ( auto i = panel_dim; i < panel_dim_max; i++ )
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
else
{
for ( auto j = 0; j < panel_len_j; j++)
{
for ( auto i = 0; i < panel_dim; i++)
p_cast[ i ] = kappa_cast * conj( a_cast[ i*inca + scata[j] ] );
for ( auto i = panel_dim; i < panel_dim_max; i++)
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
}
}
else
{
for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
{
auto lda = *bsa;
auto panel_len_j = std::min<dim_t>( panel_len-j0, BS_K );
if ( lda )
{
T* restrict aj = a_cast + *scata;
for ( auto j = 0; j < panel_len_j; j++ )
{
for ( auto i = 0; i < panel_dim; i++ )
p_cast[ i ] = kappa_cast * aj[ i*inca + j*lda ];
for ( auto i = panel_dim; i < panel_dim_max; i++ )
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
else
{
for ( auto j = 0; j < panel_len_j; j++ )
{
for ( auto i = 0; i < panel_dim; i++ )
p_cast[ i ] = kappa_cast * a_cast[ i*inca + scata[j] ];
for ( auto i = panel_dim; i < panel_dim_max; i++ )
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
}
}
for ( auto j = panel_len; j < panel_len_max; j++)
{
for ( auto i = 0; i < panel_dim_max; i++)
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
template <typename T>
void packm_ckx_ss
(
bool conja,
dim_t panel_dim,
dim_t panel_len,
dim_t panel_dim_max,
dim_t panel_len_max,
void* kappa,
void* a, inc_t* inca, inc_t* scata,
void* p, inc_t ldp
)
{
T* restrict a_cast = ( T* )a;
T* restrict p_cast = ( T* )p;
auto kappa_cast = *( T* )kappa;
if ( conja )
{
for (dim_t j = 0;j < panel_len;j++)
{
for (dim_t i = 0;i < panel_dim;i++)
p_cast[ i ] = kappa_cast * conj( a_cast[ inca[i] + scata[j] ] );
for (dim_t i = panel_dim;i < panel_dim_max;i++)
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
else
{
for (dim_t j = 0;j < panel_len;j++)
{
for (dim_t i = 0;i < panel_dim;i++)
p_cast[ i ] = kappa_cast * a_cast[ inca[i] + scata[j] ];
for (dim_t i = panel_dim;i < panel_dim_max;i++)
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
for (dim_t j = panel_len;j < panel_len_max;j++)
{
for (dim_t i = 0;i < panel_dim_max;i++)
p_cast[ i ] = convert<T>(0.0);
p_cast += ldp;
}
}
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_ckx_nb<ctype>;
INSERT_GENTFUNC_BASIC0(packm_ckx_nb);
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_ckx_ss<ctype>;
INSERT_GENTFUNC_BASIC0(packm_ckx_ss);
static decltype(&packm_ckx_nb<void>) GENARRAY( packm_ckx_nb_ukrs, packm_ckx_nb );
static decltype(&packm_ckx_ss<void>) GENARRAY( packm_ckx_ss_ukrs, packm_ckx_ss );
static void fill_scatter
(
gint_t ndim,
const dim_t* restrict len,
const inc_t* restrict stride,
dim_t BS,
inc_t off,
dim_t size,
inc_t* restrict scat,
inc_t* restrict bs
)
{
if ( size == 0 ) return;
if ( ndim == 0 )
{
*scat = 0;
*bs = 0;
return;
}
if ( ndim == 1 )
{
auto s = *stride;
// Honor the requested offset and size and return early, so that the
// general case below doesn't rewrite (or overrun) scat and bs.
for ( dim_t i = 0; i < size; i++ )
{
scat[i] = ( off + i )*s;
bs[i] = s;
}
return;
}
dim_t tot_len = 1;
for ( auto i = 0; i < ndim; i++ )
tot_len *= len[i];
assert(off >= 0);
assert(size >= 0);
assert(off+size <= tot_len);
auto len0 = len[0];
auto stride0 = stride[0];
auto off0 = off % len0;
auto off1 = off / len0;
auto size1 = ( size + off0 + len0 - 1) / len0;
inc_t pos1 = 0;
inc_t idx = 0;
for_each( ndim-1, len+1, off1, size1, pos1, stride+1,
[&]
{
auto pos = pos1 + off0 * stride0;
auto len_i = std::min( len0-off0, size-idx );
for ( auto i = 0; i < len_i; i++ )
{
scat[idx++] = pos;
pos += stride0;
}
off0 = 0;
});
assert(idx == size);
for ( idx = 0; idx < size; idx += BS )
{
auto len_i = std::min( BS, size-idx );
auto s = stride0;
for ( auto i = idx; i < idx+len_i-1; i++)
{
if (scat[i+1]-scat[i] != s)
{
s = 0;
break;
}
}
bs[idx] = s;
}
}
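// Worked example (hypothetical values): for a dimension group with
// len = {2,3}, stride = {1,4}, off = 0, size = 6, and BS = 4, the scatter
// vector is scat = {0,1,4,5,8,9}. Block 0 ({0,1,4,5}) has no constant
// element stride, so bs[0] = 0 and the packing kernel falls back to the
// element-wise scatter path; block 1 ({8,9}) has constant stride 1, so
// bs[4] = 1 and that block can use ordinary strided loads.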
void packm_tensor
(
obj_t* a,
obj_t* p,
cntx_t* cntx,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
// We begin by copying the fields of A.
bli_obj_alias_to( a, p );
// Get information about data types.
auto dt = bli_obj_dt( a );
auto dt_tar = bli_obj_target_dt( a );
auto dt_scalar = bli_obj_scalar_dt( a );
auto dt_size = bli_dt_size( dt );
if ( dt_scalar != dt || dt_tar != dt )
bli_abort();
// Extract various fields from the control tree.
auto bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
auto bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
auto schema = bli_cntl_packm_params_pack_schema( cntl );
auto bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
auto bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
auto bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
if ( schema != BLIS_PACKED_ROW_PANELS &&
schema != BLIS_PACKED_COL_PANELS )
bli_abort();
// Store the pack schema to the object.
bli_obj_set_pack_schema( schema, p );
// Clear the conjugation field from the object since matrix packing
// in BLIS is deemed to take care of all conjugation necessary.
bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
// If we are packing micropanels, mark P as dense.
bli_obj_set_uplo( BLIS_DENSE, p );
// Reset the view offsets to (0,0).
bli_obj_set_offs( 0, 0, p );
// Compute the dimensions padded by the dimension multiples. These
// dimensions will be the dimensions of the packed matrices, including
// zero-padding, and will be used by the macro- and micro-kernels.
// We compute them by starting with the effective dimensions of A (now
// in P) and aligning them to the dimension multiples (typically equal
// to register blocksizes). This does waste a little bit of space for
// level-2 operations, but that's okay with us.
auto m_p = bli_obj_length( p );
auto n_p = bli_obj_width( p );
auto m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
auto n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
// Save the padded dimensions into the packed object. It is important
// to save these dimensions since they represent the actual dimensions
// of the zero-padded matrix.
bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
// The "panel stride" of a micropanel packed object is interpreted as
// the distance between the (0,0) element of panel k and the (0,0)
// element of panel k+1. We use the padded width computed above to
// allow for zero-padding (if necessary/desired) along the far end
// of each micropanel (ie: the right edge of the matrix). Zero-padding
// can also occur along the long edge of the last micropanel if the m
// dimension of the matrix is not a whole multiple of MR.
auto ps_p = bmult_m_pack * n_p_pad;
/* Compute the total number of iterations we'll need. */
auto n_iter = m_p_pad / bmult_m_def;
// Store the strides and panel dimension in P.
bli_obj_set_strides( 1, bmult_m_pack, p );
bli_obj_set_imag_stride( 1, p );
bli_obj_set_panel_dim( bmult_m_def, p );
bli_obj_set_panel_stride( ps_p, p );
bli_obj_set_panel_length( bmult_m_def, p );
bli_obj_set_panel_width( n_p, p );
// Compute the size of the packed buffer.
auto size_p = ps_p * n_iter * dt_size;
if ( size_p == 0 ) return;
// Add the size of the scatter and block-scatter vectors to the total.
// It is never necessary to add padding for alignment because:
// 1) ps_p is always even
// 2) dt_size is a power of two >= 4
// 3) the alignment of the scatter vectors is at most 8
auto scat_size = 2 * (m_p + n_p) * sizeof(inc_t);
// Update the buffer address in p to point to the buffer associated
// with the mem_t entry acquired from the memory broker (now cached in
// the control tree node).
auto p_cast = (char*)bli_packm_alloc( size_p + scat_size, rntm, cntl, thread );
bli_obj_set_buffer( p_cast, p );
// Get the addresses of the scatter and block-scatter vectors. These are
// placed directly after the packed matrix buffer.
auto rscat = (inc_t*)(p_cast + size_p);
auto rbs = rscat + m_p;
auto cscat = rbs + m_p;
auto cbs = cscat + n_p;
auto a_cast = (char*)bli_obj_buffer_at_off( a );
auto panel_dim_off = bli_obj_row_off( a );
auto panel_len_off = bli_obj_col_off( a );
auto conja = bli_obj_conj_status( a );
auto params = (packm_tensor_params_t*)bli_obj_pack_params( a );
auto ndim_m = params->ndim_m;
auto ndim_n = params->ndim_n;
auto len_m = params->len_m;
auto len_n = params->len_n;
auto stride_m = params->stride_m;
auto stride_n = params->stride_n;
obj_t kappa_local;
auto kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );
auto packm_nb_ker = packm_ckx_nb_ukrs[ dt ];
auto packm_ss_ker = packm_ckx_ss_ukrs[ dt ];
a_cast -= ( panel_dim_off * stride_m[0] +
panel_len_off * stride_n[0] ) * dt_size;
/* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */
if ( bli_thread_am_ochief( thread ) )
{
fill_scatter
(
ndim_m,
len_m,
stride_m,
bmult_m_def,
panel_dim_off,
m_p,
rscat,
rbs
);
fill_scatter
(
ndim_n,
len_n,
stride_n,
BS_K,
panel_len_off,
n_p,
cscat,
cbs
);
}
/* Wait for the scatter vectors to be done. */
bli_thread_barrier( thread );
/* Query the number of threads and thread ids from the current thread's
packm thrinfo_t node. */
auto nt = bli_thread_n_way( thread );
auto tid = bli_thread_work_id( thread );
/* Determine the thread range and increment using the current thread's
packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
will depend on whether slab or round-robin partitioning was requested
at configure-time. */
dim_t it_start, it_end, it_inc;
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
/* Iterate over every logical micropanel in the source matrix. */
for ( auto it = 0; it < n_iter; it += 1 )
if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
{
auto panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
auto p_begin = p_cast + it*ps_p*dt_size;
auto inca = rbs[ it*bmult_m_def ];
if ( inca )
{
auto a_begin = a_cast + rscat[ it*bmult_m_def ]*dt_size;
packm_nb_ker( conja,
panel_dim_i,
n_p,
bmult_m_def,
n_p_pad,
kappa_cast,
a_begin, inca, cbs, cscat,
p_begin, bmult_m_pack );
}
else
{
auto a_begin = a_cast;
auto rscat_use = rscat + it*bmult_m_def;
packm_ss_ker( conja,
panel_dim_i,
n_p,
bmult_m_def,
n_p_pad,
kappa_cast,
a_begin, rscat_use, cscat,
p_begin, bmult_m_pack );
}
}
}
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
void PASTEMAC(ch,op) \
( \
dim_t m, \
dim_t n, \
void* x, inc_t rs_x, inc_t cs_x, \
void* b, \
void* y, inc_t* rs_y, inc_t* cs_y \
) \
{ \
ctype* restrict x_cast = (ctype*)x; \
ctype b_cast = *(ctype*)b; \
ctype* restrict y_cast = (ctype*)y; \
\
if ( PASTEMAC(ch,eq0)( b_cast ) ) \
{ \
for ( auto i = 0; i < m; i++ ) \
for ( auto j = 0; j < n; j++ ) \
PASTEMAC(ch,copys)( x_cast[ i*rs_x + j*cs_x ], y_cast[ rs_y[i] + cs_y[j] ] ); \
} \
else \
{ \
for ( auto i = 0; i < m; i++ ) \
for ( auto j = 0; j < n; j++ ) \
PASTEMAC(ch,xpbys)( x_cast[ i*rs_x + j*cs_x ], b_cast, y_cast[ rs_y[i] + cs_y[j] ] ); \
} \
}
INSERT_GENTFUNC_BASIC0(scatter_mxn);
static decltype(&bli_sscatter_mxn) GENARRAY(scatter_mxn, scatter_mxn);
void gemm_tensor
(
obj_t* a,
obj_t* b,
obj_t* c,
cntx_t* cntx,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
auto dt = bli_obj_dt( c );
auto dt_size = bli_dt_size( dt );
auto m = bli_obj_length( c );
auto n = bli_obj_width( c );
auto k = bli_obj_width( a );
auto a_cast = (char*)bli_obj_buffer_at_off( a );
auto pd_a = bli_obj_panel_dim( a );
auto ps_a = bli_obj_panel_stride( a );
auto b_cast = (char*)bli_obj_buffer_at_off( b );
auto pd_b = bli_obj_panel_dim( b );
auto ps_b = bli_obj_panel_stride( b );
auto c_cast = (char*)bli_obj_buffer_at_off( c );
auto rs_c0 = bli_obj_row_stride( c );
auto cs_c0 = bli_obj_col_stride( c );
auto off_m = bli_obj_row_off( c );
auto off_n = bli_obj_col_off( c );
auto params = (gemm_tensor_params_t*)bli_obj_ker_params( c );
auto ndim_m = params->ndim_m;
auto ndim_n = params->ndim_n;
auto len_m = params->len_m;
auto len_n = params->len_n;
auto stride_m = params->stride_m;
auto stride_n = params->stride_n;
if ( rs_c0 != stride_m[0] || cs_c0 != stride_n[0] )
{
std::swap( ndim_m, ndim_n );
std::swap( len_m, len_n );
std::swap( stride_m, stride_n );
}
/* If any dimension is zero, return immediately. */
if ( bli_zero_dim3( m, n, k ) ) return;
c_cast -= ( off_m * stride_m[0] +
off_n * stride_n[0] ) * dt_size;
// Detach and multiply the scalars attached to A and B.
// NOTE: We know that the internal scalars of A and B are already of the
// target datatypes because the necessary typecasting would have already
// taken place during bli_packm_init().
obj_t scalar_a;
obj_t scalar_b;
bli_obj_scalar_detach( a, &scalar_a );
bli_obj_scalar_detach( b, &scalar_b );
bli_mulsc( &scalar_a, &scalar_b );
// Grab the addresses of the internal scalar buffers for the scalar
// merged above and the scalar attached to C.
// NOTE: We know that scalar_b is of type dt due to the above code
// that casts the scalars of A and B to dt via scalar_a and scalar_b,
// and we know that the internal scalar in C is already of the type dt
// due to the casting in the implementation of bli_obj_scalar_attach().
auto alpha_cast = (char*)bli_obj_internal_scalar_buffer( &scalar_b );
auto beta_cast = (char*)bli_obj_internal_scalar_buffer( c );
/* Alias some constants to simpler names. */
auto MR = pd_a;
auto NR = pd_b;
/* Query the context for the micro-kernel address and cast it to its
function pointer type. */
auto gemm_ukr = (gemm_ukr_vft)bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx );
/* Temporary C buffer for edge cases. Note that the strides of this
temporary buffer are set so that they match the storage of the
original C matrix. For example, if C is column-stored, ct will be
column-stored as well. */
char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
auto col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx );
auto rs_ct = ( col_pref ? 1 : NR );
auto cs_ct = ( col_pref ? MR : 1 );
auto zero = (char*)bli_obj_buffer_for_const( dt, &BLIS_ZERO );
/*
Assumptions/assertions:
rs_a == 1
cs_a == PACKMR
pd_a == MR
ps_a == stride to next micro-panel of A
rs_b == PACKNR
cs_b == 1
pd_b == NR
ps_b == stride to next micro-panel of B
rs_c == (no assumptions)
cs_c == (no assumptions)
*/
auto scat_size = 2 * (m + n) * sizeof(inc_t);
auto rscat_c = (inc_t*)bli_packm_alloc_ex( scat_size, BLIS_BUFFER_FOR_GEN_USE, rntm, cntl, thread );
auto rbs_c = rscat_c + m;
auto cscat_c = rbs_c + m;
auto cbs_c = cscat_c + n;
/* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */
if ( bli_thread_am_ochief( thread ) )
{
fill_scatter
(
ndim_m,
len_m,
stride_m,
MR,
off_m,
m,
rscat_c,
rbs_c
);
fill_scatter
(
ndim_n,
len_n,
stride_n,
NR,
off_n,
n,
cscat_c,
cbs_c
);
}
/* Wait for the scatter vectors to be done. */
bli_thread_barrier( thread );
/* Compute number of primary and leftover components of the m and n
dimensions. */
auto n_iter = n / NR;
auto n_left = n % NR;
auto m_iter = m / MR;
auto m_left = m % MR;
if ( n_left ) ++n_iter;
if ( m_left ) ++m_iter;
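/* e.g. m = 14 with MR = 6 yields m_iter = 3 micropanel rows, the last one
an edge case of m_left = 2 rows. */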
/* Determine some increments used to step through A, B, and C. */
auto rstep_a = ps_a * dt_size;
auto cstep_b = ps_b * dt_size;
/* Save the virtual microkernel address and the params. */
auxinfo_t aux;
bli_auxinfo_set_ukr( (void*)gemm_ukr, &aux );
bli_auxinfo_set_params( params, &aux );
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
loop around the microkernel. Here we query the thrinfo_t node for the
1st (ir) loop around the microkernel. */
auto caucus = bli_thrinfo_sub_node( thread );
/* Query the number of threads and thread ids for each loop. */
auto jr_nt = bli_thread_n_way( thread );
auto jr_tid = bli_thread_work_id( thread );
auto ir_nt = bli_thread_n_way( caucus );
auto ir_tid = bli_thread_work_id( caucus );
/* Determine the thread range and increment for the 2nd and 1st loops.
NOTE: The definition of bli_thread_range_jrir() will depend on whether
slab or round-robin partitioning was requested at configure-time. */
dim_t jr_start, jr_end;
dim_t ir_start, ir_end;
dim_t jr_inc, ir_inc;
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );
/* Loop over the n dimension (NR columns at a time). */
for ( auto j = jr_start; j < jr_end; j += jr_inc )
{
auto b1 = b_cast + j * cstep_b;
auto n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );
/* Initialize our next panel of B to be the current panel of B. */
auto b2 = b1;
/* Loop over the m dimension (MR rows at a time). */
for ( auto i = ir_start; i < ir_end; i += ir_inc )
{
auto a1 = a_cast + i * rstep_a;
auto rscat_c1 = rscat_c + i * MR;
auto rbs_c1 = rbs_c + i * MR;
auto cscat_c1 = cscat_c + j * NR;
auto cbs_c1 = cbs_c + j * NR;
auto m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );
/* Compute the addresses of the next panels of A and B. */
auto a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
{
a2 = a_cast;
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
b2 = b_cast;
}
/* Save addresses of next panels of A and B to the auxinfo_t
object. */
bli_auxinfo_set_next_a( a2, &aux );
bli_auxinfo_set_next_b( b2, &aux );
auto rs_c = *rbs_c1;
auto cs_c = *cbs_c1;
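/* A nonzero leading block-scatter entry in both dimensions means this
MR x NR tile of C is regularly strided, so the microkernel can update it
in place; otherwise the tile is computed into the temporary ct buffer
and scattered to C element by element. */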
if ( rs_c && cs_c )
{
auto c11 = c_cast + ( *rscat_c1 + *cscat_c1 ) * dt_size;
/* Invoke the gemm micro-kernel. */
gemm_ukr
(
m_cur,
n_cur,
k,
alpha_cast,
a1,
b1,
beta_cast,
c11, rs_c, cs_c,
&aux,
cntx
);
}
else
{
/* Invoke the gemm micro-kernel. */
gemm_ukr
(
MR,
NR,
k,
alpha_cast,
a1,
b1,
zero,
&ct, rs_ct, cs_ct,
&aux,
cntx
);
/* Scatter to C. */
scatter_mxn[ dt ]
(
m_cur, n_cur,
&ct, rs_ct, cs_ct,
beta_cast,
c_cast, rscat_c1, cscat_c1
);
}
}
}
}
static bool has_unit_stride( const std::vector<inc_t>& stride )
{
for ( auto s : stride )
if ( s == 1 )
return true;
return false;
}
void tcontract( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
const void* alpha, const void* a, std::vector<inc_t> rs_a, std::vector<inc_t> cs_a,
const void* b, std::vector<inc_t> rs_b, std::vector<inc_t> cs_b,
const void* beta, void* c, std::vector<inc_t> rs_c, std::vector<inc_t> cs_c )
{
if ( rs_a.size() != m.size() ||
rs_b.size() != k.size() ||
rs_c.size() != m.size() )
bli_check_error_code( BLIS_INVALID_ROW_STRIDE );
if ( cs_a.size() != k.size() ||
cs_b.size() != n.size() ||
cs_c.size() != n.size() )
bli_check_error_code( BLIS_INVALID_COL_STRIDE );
dim_t m_mat = 1;
dim_t n_mat = 1;
dim_t k_mat = 1;
for ( auto& i : m ) m_mat *= i;
for ( auto& i : n ) n_mat *= i;
for ( auto& i : k ) k_mat *= i;
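// Sort each dimension group by increasing stride (a simple bubble sort),
// swapping the corresponding A/B/C stride entries in lockstep. Afterwards,
// index 0 of each group is the fastest-varying dimension, whose stride can
// be handed to BLIS as the matrix stride below.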
auto& stride_m = has_unit_stride( rs_c ) ? rs_c : rs_a;
for ( int i = 1;i < m.size(); i++ )
for ( int j = 0;j < m.size()-i; j++ )
if ( stride_m[j] > stride_m[j+1] )
{
std::swap( rs_a[j], rs_a[j+1] );
std::swap( rs_c[j], rs_c[j+1] );
}
auto& stride_n = has_unit_stride( cs_c ) ? cs_c : cs_b;
for ( int i = 1;i < n.size(); i++ )
for ( int j = 0;j < n.size()-i; j++ )
if ( stride_n[j] > stride_n[j+1] )
{
std::swap( cs_b[j], cs_b[j+1] );
std::swap( cs_c[j], cs_c[j+1] );
}
auto& stride_k = has_unit_stride( cs_a ) ? cs_a : rs_b;
for ( int i = 1;i < k.size(); i++ )
for ( int j = 0;j < k.size()-i; j++ )
if ( stride_k[j] > stride_k[j+1] )
{
std::swap( cs_a[j], cs_a[j+1] );
std::swap( rs_b[j], rs_b[j+1] );
}
if ( rs_a.empty() ) rs_a.push_back( 1 );
if ( cs_a.empty() ) cs_a.push_back( 1 );
if ( rs_b.empty() ) rs_b.push_back( 1 );
if ( cs_b.empty() ) cs_b.push_back( 1 );
if ( rs_c.empty() ) rs_c.push_back( 1 );
if ( cs_c.empty() ) cs_c.push_back( 1 );
obj_t a_o, b_o, c_o;
bli_obj_create_with_attached_buffer( dt, m_mat, k_mat, const_cast<void*>(a), rs_a[0], cs_a[0], &a_o );
bli_obj_create_with_attached_buffer( dt, k_mat, n_mat, const_cast<void*>(b), rs_b[0], cs_b[0], &b_o );
bli_obj_create_with_attached_buffer( dt, m_mat, n_mat, c , rs_c[0], cs_c[0], &c_o );
packm_tensor_params_t params_a( m.size(), m.data(), rs_a.data(),
k.size(), k.data(), cs_a.data() );
packm_tensor_params_t params_b( n.size(), n.data(), cs_b.data(),
k.size(), k.data(), rs_b.data() );
gemm_tensor_params_t params_c( m.size(), m.data(), rs_c.data(),
n.size(), n.data(), cs_c.data() );
bli_obj_set_pack_fn( packm_tensor, &a_o );
bli_obj_set_pack_fn( packm_tensor, &b_o );
bli_obj_set_ker_fn( gemm_tensor, &c_o );
bli_obj_set_pack_params( &params_a, &a_o );
bli_obj_set_pack_params( &params_b, &b_o );
bli_obj_set_ker_params( &params_c, &c_o );
obj_t alpha_o, beta_o;
bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(alpha), &alpha_o );
bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(beta), &beta_o );
rntm_t rntm;
bli_rntm_init_from_global( &rntm );
bli_rntm_disable_l3_sup( &rntm );
bli_gemm_ex( &alpha_o, &a_o, &b_o, &beta_o, &c_o, NULL, &rntm );
}
int main()
{
auto N = 5;
gint_t ndim_a = 4;
gint_t ndim_b = 4;
gint_t ndim_c = 4;
std::vector<dim_t> len_a(ndim_a, N);
std::vector<dim_t> len_b(ndim_b, N);
std::vector<dim_t> len_c(ndim_c, N);
std::vector<inc_t> stride_a(ndim_a, 1);
std::vector<inc_t> stride_b(ndim_b, 1);
std::vector<inc_t> stride_c(ndim_c, 1);
for ( gint_t i = 1; i < ndim_a; i++ )
stride_a[i] = stride_a[i-1] * len_a[i - 1];
for ( gint_t i = 1; i < ndim_b; i++ )
stride_b[i] = stride_b[i-1] * len_b[i - 1];
for ( gint_t i = 1; i < ndim_c; i++ )
stride_c[i] = stride_c[i-1] * len_c[i - 1];
std::vector<int> dim_a(ndim_a);
std::vector<int> dim_b(ndim_b);
std::vector<int> dim_c(ndim_c);
std::iota(dim_a.begin(), dim_a.end(), 0);
std::iota(dim_b.begin(), dim_b.end(), 0);
std::iota(dim_c.begin(), dim_c.end(), 0);
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
do
do
do
{
auto dt = ( num_t )dt_;
auto ndim_m = (ndim_a + ndim_c - ndim_b)/2;
auto ndim_k = (ndim_a + ndim_b - ndim_c)/2;
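// With ndim_a = ndim_b = ndim_c = 4 this gives ndim_m = ndim_k = 2: A
// carries the two m- and two k-dimensions, B the two k- and two
// n-dimensions, and C the two m- and two n-dimensions.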
std::vector<dim_t> m(len_a.begin(), len_a.begin()+ndim_m);
std::vector<dim_t> n(len_b.begin()+ndim_k, len_b.end());
std::vector<dim_t> k(len_b.begin(), len_b.begin()+ndim_k);
std::vector<inc_t> rs_a(stride_a.begin(), stride_a.begin()+ndim_m);
std::vector<inc_t> cs_a(stride_a.begin()+ndim_m, stride_a.end());
std::vector<inc_t> rs_b(stride_b.begin(), stride_b.begin()+ndim_k);
std::vector<inc_t> cs_b(stride_b.begin()+ndim_k, stride_b.end());
std::vector<inc_t> rs_c(stride_c.begin(), stride_c.begin()+ndim_m);
std::vector<inc_t> cs_c(stride_c.begin()+ndim_m, stride_c.end());
dim_t m_tot = 1;
dim_t n_tot = 1;
dim_t k_tot = 1;
for ( auto i : m ) m_tot *= i;
for ( auto i : n ) n_tot *= i;
for ( auto i : k ) k_tot *= i;
obj_t a, b, c, c_ref, norm;
bli_obj_create( dt, m_tot*k_tot, 1, 1, 1, &a );
bli_obj_create( dt, k_tot*n_tot, 1, 1, 1, &b );
bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c );
bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c_ref );
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
bli_randv( &a );
bli_randv( &b );
bli_randv( &c );
bli_copyv( &c, &c_ref );
tcontract( dt, m, n, k,
bli_obj_buffer_for_const( dt, &BLIS_ONE ),
bli_obj_buffer( &a ), rs_a, cs_a,
bli_obj_buffer( &b ), rs_b, cs_b,
bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
bli_obj_buffer( &c ), rs_c, cs_c );
tcontract_ref( dt, m, n, k,
bli_obj_buffer_for_const( dt, &BLIS_ONE ),
bli_obj_buffer( &a ), rs_a, cs_a,
bli_obj_buffer( &b ), rs_b, cs_b,
bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
bli_obj_buffer( &c_ref ), rs_c, cs_c );
bli_subv( &c_ref, &c );
bli_normfv( &c, &norm );
double normr, normi;
bli_getsc( &norm, &normr, &normi );
printf("dt: %d, dim_a: [%d,%d,%d,%d], dim_b: [%d,%d,%d,%d], dim_c: [%d,%d,%d,%d], norm: %g\n",
dt, dim_a[0], dim_a[1], dim_a[2], dim_a[3],
dim_b[0], dim_b[1], dim_b[2], dim_b[3],
dim_c[0], dim_c[1], dim_c[2], dim_c[3],
normr / std::sqrt( bli_obj_vector_dim( &c ) ) );
bli_obj_free( &a );
bli_obj_free( &b );
bli_obj_free( &c );
bli_obj_free( &c_ref );
bli_obj_free( &norm );
}
while (std::next_permutation(dim_a.begin(), dim_a.end()));
while (std::next_permutation(dim_b.begin(), dim_b.end()));
while (std::next_permutation(dim_c.begin(), dim_c.end()));
}

View File

@@ -0,0 +1,67 @@
#include "tcontract_ref.hpp"
template <typename T>
void tcontract_ref( const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
{
auto alpha_cast = *( T* )alpha;
auto beta_cast = *( T* )beta;
auto a_cast = ( T* )a;
auto b_cast = ( T* )b;
auto c_cast = ( T* )c;
for_each(m.size(), m.data(), a_cast, rs_a.data(), c_cast, rs_c.data(),
[&]
{
for_each(n.size(), n.data(), b_cast, cs_b.data(), c_cast, cs_c.data(),
[&]
{
auto ab = convert<T>(0.0);
for_each(k.size(), k.data(), a_cast, cs_a.data(), b_cast, rs_b.data(),
[&]
{
ab += (*a_cast) * (*b_cast);
});
if ( beta_cast == convert<T>(0.0) )
{
*c_cast = alpha_cast * ab;
}
else
{
*c_cast = alpha_cast * ab + beta_cast * (*c_cast);
}
});
assert(b_cast == b);
});
assert(a_cast == a);
assert(c_cast == c);
}
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &tcontract_ref<ctype>;
INSERT_GENTFUNC_BASIC0(tcontract_ref);
static decltype(&tcontract_ref<void>) GENARRAY( tcontract_ref_impl, tcontract_ref );
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
{
tcontract_ref_impl[ dt ]
(
m, n, k,
alpha, a, rs_a, cs_a,
b, rs_b, cs_b,
beta, c, rs_c, cs_c
);
}

View File

@@ -0,0 +1,100 @@
#include "blis.h"
#include "complex_math.hpp"
#include <vector>
#include <array>
#include <cassert>
inline void increment(inc_t, gint_t) {}
template <typename T, typename... Args>
void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args)
{
off += s[i]*n;
increment(n, i, args...);
}
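// for_each_impl visits `len` consecutive logical elements of an
// ndim-dimensional tensor starting at linear offset `off`, calling `body`
// once per element. Each (pointer, stride-array) pair in `args` is advanced
// odometer-style: dimension 0 varies fastest, and when an index wraps, the
// pointer is rewound along that dimension and the next dimension advances.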
template <typename Body, typename... Args>
void for_each_impl(gint_t ndim, const dim_t* n,
dim_t off, dim_t len,
Body& body,
Args&... args)
{
std::array<dim_t,8> i = {};
assert( ndim <= i.size() );
if ( off )
{
for ( gint_t k = 0; k < ndim; k++ )
{
i[k] = off % n[k];
off /= n[k];
increment(i[k], k, args...);
}
}
for ( dim_t pos = 0; pos < len; pos++ )
{
body();
for ( gint_t k = 0; k < ndim; k++ )
{
if ( i[k] == n[k]-1 )
{
increment(-i[k], k, args...);
i[k] = 0;
}
else
{
increment(1, k, args...);
i[k]++;
break;
}
}
}
}
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
dim_t off, dim_t len,
T& a, const inc_t* s_a,
Body&& body)
{
for_each_impl( ndim, n, off, len, body, a, s_a );
}
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
dim_t off, dim_t len,
T& a, const inc_t* s_a,
T& b, const inc_t* s_b,
Body&& body)
{
for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b );
}
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
T& a, const inc_t* s_a,
Body&& body)
{
dim_t len = 1;
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
for_each_impl( ndim, n, 0, len, body, a, s_a );
}
template <typename T, typename Body>
void for_each(gint_t ndim, const dim_t* n,
T& a, const inc_t* s_a,
T& b, const inc_t* s_b,
Body&& body)
{
dim_t len = 1;
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b );
}
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c );