mirror of
https://github.com/amd/blis.git
synced 2026-05-13 02:25:39 +00:00
Move edge cases to gemm ukr; more user-custom mods. (#583)
Details: - Moved edge-case handling into the gemm microkernel. This required changing the microkernel API to take m and n dimension parameters. This required updating all existing gemm microkernel function pointer types, function signatures, and related definitions to take m and n dimensions. We also updated all existing kernels in the 'kernels' directory to take m and n dimensions, and implemented edge-case handling within those microkernels via a collection of new C preprocessor macros defined within bli_edge_case_macro_defs.h. Also removed the assembly code that formerly would handle general stride IO on the microtile, since this can now be handled by the same code that does edge cases. - Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and bli_trsm_cntl_create(), where this function pointer is used in lieu of the default macrokernel when it is non-NULL, and ignored when it is NULL. - Re-implemented macrokernel in bli_gemm_ker_var2.c to be a single function using byte pointers rather than one function for each floating-point datatype. Also, obtain the microkernel function pointer from the .ukr field of the params struct embedded within the obj_t for matrix C (assuming params is non-NULL and contains a non-NULL value in the .ukr field). Communicate both the gemm microkernel pointer to use as well as the params struct to the microkernel via the auxinfo_t struct. - Defined gemm_ker_params_t type (for the aforementioned obj_t.params struct) in bli_gemm_var.h. - Retired the separate _md macrokernel for mixed datatype computation. We now use the reimplemented bli_gemm_ker_var2() instead. - Updated gemmt macrokernels to pass m and n dimensions into microkernel calls. - Removed edge-case handling from trmm and trsm macrokernels. - Moved most of bli_packm_alloc() code into a new helper function, bli_packm_alloc_ex(). - Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c. 
- Added test/syrk_diagonal and test/tensor_contraction directories with associated code to test those operations.
This commit is contained in:
@@ -37,6 +37,8 @@
|
||||
|
||||
void bli_zgemm_template_noopt
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a1,
|
||||
@@ -88,8 +90,7 @@ void bli_zgemm_template_noopt
|
||||
|
||||
dim_t l, j, i;
|
||||
|
||||
dcomplex ab[ bli_zmr *
|
||||
bli_znr ];
|
||||
dcomplex ab[ mr * nr ];
|
||||
dcomplex* abij;
|
||||
dcomplex ai, bj;
|
||||
|
||||
@@ -137,16 +138,16 @@ void bli_zgemm_template_noopt
|
||||
if ( bli_zeq0( *beta ) )
|
||||
{
|
||||
/* c11 := ab */
|
||||
bli_zcopys_mxn( mr,
|
||||
nr,
|
||||
bli_zcopys_mxn( m,
|
||||
n,
|
||||
ab, rs_ab, cs_ab,
|
||||
c11, rs_c, cs_c );
|
||||
}
|
||||
else
|
||||
{
|
||||
/* c11 := beta * c11 + ab */
|
||||
bli_zxpbys_mxn( mr,
|
||||
nr,
|
||||
bli_zxpbys_mxn( m,
|
||||
n,
|
||||
ab, rs_ab, cs_ab,
|
||||
beta,
|
||||
c11, rs_c, cs_c );
|
||||
|
||||
@@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt
|
||||
*/
|
||||
const num_t dt = BLIS_DCOMPLEX;
|
||||
|
||||
const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
|
||||
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
|
||||
|
||||
const inc_t rs_b = packnr;
|
||||
@@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt
|
||||
/* b11 = alpha * b11 - a10 * b01; */
|
||||
bli_zgemm_template_noopt
|
||||
(
|
||||
mr,
|
||||
nr,
|
||||
k,
|
||||
minus_one,
|
||||
a10,
|
||||
|
||||
@@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt
|
||||
*/
|
||||
const num_t dt = BLIS_DCOMPLEX;
|
||||
|
||||
const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
|
||||
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
|
||||
|
||||
const inc_t rs_b = packnr;
|
||||
@@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt
|
||||
/* b11 = alpha * b11 - a12 * b21; */
|
||||
bli_zgemm_template_noopt
|
||||
(
|
||||
mr,
|
||||
nr,
|
||||
k,
|
||||
minus_one,
|
||||
a12,
|
||||
b21,
|
||||
a10,
|
||||
b01,
|
||||
alpha,
|
||||
b11, rs_b, cs_b,
|
||||
data
|
||||
|
||||
@@ -36,16 +36,35 @@
|
||||
#include "blis.h"
|
||||
|
||||
void* bli_packm_alloc
|
||||
(
|
||||
siz_t size_needed,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
)
|
||||
(
|
||||
siz_t size_needed,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
)
|
||||
{
|
||||
// Query the pack buffer type from the control tree node.
|
||||
packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl );
|
||||
|
||||
return bli_packm_alloc_ex
|
||||
(
|
||||
size_needed,
|
||||
pack_buf_type,
|
||||
rntm,
|
||||
cntl,
|
||||
thread
|
||||
);
|
||||
}
|
||||
|
||||
void* bli_packm_alloc_ex
|
||||
(
|
||||
siz_t size_needed,
|
||||
packbuf_t pack_buf_type,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
)
|
||||
{
|
||||
// Query the address of the mem_t entry within the control tree node.
|
||||
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl );
|
||||
|
||||
@@ -55,7 +74,7 @@ void* bli_packm_alloc
|
||||
siz_t cntl_mem_size = 0;
|
||||
|
||||
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
||||
cntl_mem_size = bli_mem_size( cntl_mem_p );
|
||||
cntl_mem_size = bli_mem_size( cntl_mem_p );
|
||||
|
||||
if ( cntl_mem_size < size_needed )
|
||||
{
|
||||
@@ -64,14 +83,15 @@ void* bli_packm_alloc
|
||||
// The chief thread releases the existing block associated with
|
||||
// the mem_t entry in the control tree, and then re-acquires a
|
||||
// new block, saving the associated mem_t entry to local_mem_s.
|
||||
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
||||
{
|
||||
bli_pba_release
|
||||
(
|
||||
rntm,
|
||||
cntl_mem_p
|
||||
);
|
||||
}
|
||||
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
||||
{
|
||||
bli_pba_release
|
||||
(
|
||||
rntm,
|
||||
cntl_mem_p
|
||||
);
|
||||
}
|
||||
|
||||
bli_pba_acquire_m
|
||||
(
|
||||
rntm,
|
||||
@@ -89,11 +109,11 @@ void* bli_packm_alloc
|
||||
// this thread's control tree node.
|
||||
*cntl_mem_p = *local_mem_p;
|
||||
|
||||
// Barrier so that the master thread doesn't return from the function
|
||||
// before we are done reading.
|
||||
bli_thread_barrier( thread );
|
||||
// Barrier so that the master thread doesn't return from the function
|
||||
// before we are done reading.
|
||||
bli_thread_barrier( thread );
|
||||
}
|
||||
|
||||
return bli_mem_buffer( cntl_mem_p );
|
||||
return bli_mem_buffer( cntl_mem_p );
|
||||
}
|
||||
|
||||
|
||||
@@ -32,11 +32,20 @@
|
||||
|
||||
*/
|
||||
|
||||
BLIS_EXPORT_BLIS void* bli_packm_alloc
|
||||
(
|
||||
siz_t size_needed,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
BLIS_EXPORT_BLIS void* bli_packm_alloc
|
||||
(
|
||||
siz_t size_needed,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
BLIS_EXPORT_BLIS void* bli_packm_alloc_ex
|
||||
(
|
||||
siz_t size_needed,
|
||||
packbuf_t pack_buf_type,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
|
||||
@@ -57,7 +57,14 @@ void bli_l3_cntl_create_if
|
||||
family == BLIS_GEMMT ||
|
||||
family == BLIS_TRMM )
|
||||
{
|
||||
*cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b );
|
||||
*cntl_use = bli_gemm_cntl_create
|
||||
(
|
||||
rntm,
|
||||
family,
|
||||
schema_a,
|
||||
schema_b,
|
||||
bli_obj_ker_fn( c )
|
||||
);
|
||||
}
|
||||
else // if ( family == BLIS_TRSM )
|
||||
{
|
||||
@@ -66,7 +73,14 @@ void bli_l3_cntl_create_if
|
||||
if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT;
|
||||
else side = BLIS_RIGHT;
|
||||
|
||||
*cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b );
|
||||
*cntl_use = bli_trsm_cntl_create
|
||||
(
|
||||
rntm,
|
||||
side,
|
||||
schema_a,
|
||||
schema_b,
|
||||
bli_obj_ker_fn( c )
|
||||
);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -47,6 +47,8 @@
|
||||
\
|
||||
typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
|
||||
@@ -51,6 +51,8 @@ void PASTEMAC0(opname) \
|
||||
\
|
||||
num_t dt = bli_obj_dt( c ); \
|
||||
\
|
||||
dim_t m = bli_obj_length( c ); \
|
||||
dim_t n = bli_obj_width( c ); \
|
||||
dim_t k = bli_obj_width( a ); \
|
||||
void* buf_a = bli_obj_buffer_at_off( a ); \
|
||||
void* buf_b = bli_obj_buffer_at_off( b ); \
|
||||
@@ -75,6 +77,8 @@ void PASTEMAC0(opname) \
|
||||
\
|
||||
f \
|
||||
( \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
buf_alpha, \
|
||||
buf_a, \
|
||||
|
||||
@@ -42,6 +42,8 @@
|
||||
\
|
||||
void PASTEMAC(ch,opname) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype_out* restrict alpha, \
|
||||
ctype_in* restrict a, \
|
||||
|
||||
@@ -39,6 +39,8 @@
|
||||
\
|
||||
void PASTEMAC(ch,opname) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
@@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \
|
||||
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
||||
\
|
||||
/* Invoke the typed function for the given datatype. */ \
|
||||
f( \
|
||||
k, \
|
||||
alpha, \
|
||||
a, \
|
||||
b, \
|
||||
beta, \
|
||||
c, rs_c, cs_c, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
f \
|
||||
( \
|
||||
m, \
|
||||
n, \
|
||||
k, \
|
||||
alpha, \
|
||||
a, \
|
||||
b, \
|
||||
beta, \
|
||||
c, rs_c, cs_c, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
|
||||
INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR )
|
||||
@@ -98,17 +103,18 @@ void PASTEMAC(ch,opname) \
|
||||
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
||||
\
|
||||
/* Invoke the typed function for the given datatype. */ \
|
||||
f( \
|
||||
k, \
|
||||
alpha, \
|
||||
a1x, \
|
||||
a11, \
|
||||
bx1, \
|
||||
b11, \
|
||||
c11, rs_c, cs_c, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
f \
|
||||
( \
|
||||
k, \
|
||||
alpha, \
|
||||
a1x, \
|
||||
a11, \
|
||||
bx1, \
|
||||
b11, \
|
||||
c11, rs_c, cs_c, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
|
||||
INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR )
|
||||
@@ -136,13 +142,14 @@ void PASTEMAC(ch,opname) \
|
||||
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
||||
\
|
||||
/* Invoke the typed function for the given datatype. */ \
|
||||
f( \
|
||||
a, \
|
||||
b, \
|
||||
c, rs_c, cs_c, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
f \
|
||||
( \
|
||||
a, \
|
||||
b, \
|
||||
c, rs_c, cs_c, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
|
||||
INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR )
|
||||
|
||||
@@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create
|
||||
rntm_t* rntm,
|
||||
opid_t family,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
)
|
||||
{
|
||||
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b );
|
||||
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker );
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -53,18 +54,22 @@ cntl_t* bli_gemmbp_cntl_create
|
||||
rntm_t* rntm,
|
||||
opid_t family,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
)
|
||||
{
|
||||
void_fp macro_kernel_fp;
|
||||
|
||||
// Use the function pointers to the macrokernels that use slab
|
||||
// assignment of micropanels to threads in the jr and ir loops.
|
||||
// Choose the default macrokernel based on the operation family...
|
||||
if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2;
|
||||
else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
|
||||
else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2;
|
||||
else /* should never execute */ macro_kernel_fp = NULL;
|
||||
|
||||
// ...unless a non-NULL kernel function pointer is passed in, in which
|
||||
// case we use that instead.
|
||||
if ( ker ) macro_kernel_fp = ker;
|
||||
|
||||
// Create two nodes for the macro-kernel.
|
||||
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node
|
||||
(
|
||||
|
||||
@@ -38,7 +38,8 @@ cntl_t* bli_gemm_cntl_create
|
||||
rntm_t* rntm,
|
||||
opid_t family,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -48,7 +49,8 @@ cntl_t* bli_gemmbp_cntl_create
|
||||
rntm_t* rntm,
|
||||
opid_t family,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
);
|
||||
|
||||
#if 0
|
||||
|
||||
@@ -283,90 +283,3 @@ void bli_gemm_front
|
||||
#endif
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
#if 0
|
||||
if ( bli_obj_dt( a ) != bli_obj_dt( b ) ||
|
||||
bli_obj_dt( a ) != bli_obj_dt( c ) ||
|
||||
bli_obj_comp_prec( c ) != bli_obj_prec( c ) )
|
||||
{
|
||||
const bool a_is_real = bli_obj_is_real( a );
|
||||
const bool a_is_comp = bli_obj_is_complex( a );
|
||||
const bool b_is_real = bli_obj_is_real( b );
|
||||
const bool b_is_comp = bli_obj_is_complex( b );
|
||||
const bool c_is_real = bli_obj_is_real( c );
|
||||
const bool c_is_comp = bli_obj_is_complex( c );
|
||||
|
||||
const bool a_is_single = bli_obj_is_single_prec( a );
|
||||
const bool a_is_double = bli_obj_is_double_prec( a );
|
||||
const bool b_is_single = bli_obj_is_single_prec( b );
|
||||
const bool b_is_double = bli_obj_is_double_prec( b );
|
||||
const bool c_is_single = bli_obj_is_single_prec( c );
|
||||
const bool c_is_double = bli_obj_is_double_prec( c );
|
||||
|
||||
const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC;
|
||||
const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC;
|
||||
|
||||
const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) ||
|
||||
bli_obj_domain( c ) != bli_obj_domain( b );
|
||||
|
||||
( void )a_is_real; ( void )a_is_comp;
|
||||
( void )b_is_real; ( void )b_is_comp;
|
||||
( void )c_is_real; ( void )c_is_comp;
|
||||
( void )a_is_single; ( void )a_is_double;
|
||||
( void )b_is_single; ( void )b_is_double;
|
||||
( void )c_is_single; ( void )c_is_double;
|
||||
( void )comp_single; ( void )comp_double;
|
||||
|
||||
if (
|
||||
//( c_is_comp && a_is_comp && b_is_real ) ||
|
||||
//( c_is_comp && a_is_real && b_is_comp ) ||
|
||||
//( c_is_real && a_is_comp && b_is_comp ) ||
|
||||
//( c_is_comp && a_is_real && b_is_real ) ||
|
||||
//( c_is_real && a_is_comp && b_is_real ) ||
|
||||
//( c_is_real && a_is_real && b_is_comp ) ||
|
||||
//FALSE
|
||||
TRUE
|
||||
)
|
||||
{
|
||||
if (
|
||||
( c_is_single && a_is_single && b_is_single && mixeddomain ) ||
|
||||
( c_is_single && a_is_single && b_is_single && comp_single ) ||
|
||||
( c_is_single && a_is_single && b_is_single && comp_double ) ||
|
||||
( c_is_single && a_is_single && b_is_double ) ||
|
||||
( c_is_single && a_is_double && b_is_single ) ||
|
||||
( c_is_double && a_is_single && b_is_single ) ||
|
||||
( c_is_single && a_is_double && b_is_double ) ||
|
||||
( c_is_double && a_is_single && b_is_double ) ||
|
||||
( c_is_double && a_is_double && b_is_single ) ||
|
||||
( c_is_double && a_is_double && b_is_double && comp_single ) ||
|
||||
( c_is_double && a_is_double && b_is_double && comp_double ) ||
|
||||
( c_is_double && a_is_double && b_is_double && mixeddomain ) ||
|
||||
FALSE
|
||||
)
|
||||
bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl );
|
||||
else
|
||||
bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl );
|
||||
}
|
||||
else
|
||||
bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl );
|
||||
return;
|
||||
}
|
||||
#else
|
||||
#if 0
|
||||
// If any of the storage datatypes differ, or if the execution precision
|
||||
// differs from the storage precision of C, utilize the mixed datatype
|
||||
// code path.
|
||||
// NOTE: We could check the exec dt against the storage dt of C, but for
|
||||
// now we don't support the caller setting the execution domain
|
||||
// explicitly.
|
||||
if ( bli_obj_dt( a ) != bli_obj_dt( b ) ||
|
||||
bli_obj_dt( a ) != bli_obj_dt( c ) ||
|
||||
bli_obj_comp_prec( c ) != bli_obj_prec( c ) )
|
||||
{
|
||||
bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl );
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
@@ -35,28 +35,44 @@
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
#define FUNCPTR_T gemm_fp
|
||||
typedef void (*xpbys_mxn_vft)
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
void* x, inc_t rs_x, inc_t cs_x,
|
||||
void* b,
|
||||
void* y, inc_t rs_y, inc_t cs_y
|
||||
);
|
||||
|
||||
typedef void (*FUNCPTR_T)
|
||||
(
|
||||
pack_t schema_a,
|
||||
pack_t schema_b,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
void* alpha,
|
||||
void* a, inc_t cs_a, inc_t is_a,
|
||||
dim_t pd_a, inc_t ps_a,
|
||||
void* b, inc_t rs_b, inc_t is_b,
|
||||
dim_t pd_b, inc_t ps_b,
|
||||
void* beta,
|
||||
void* c, inc_t rs_c, inc_t cs_c,
|
||||
cntx_t* cntx,
|
||||
rntm_t* rntm,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
#undef GENTFUNC2
|
||||
#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \
|
||||
\
|
||||
void PASTEMAC2(chx,chy,op) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
void* x, inc_t rs_x, inc_t cs_x, \
|
||||
void* b, \
|
||||
void* y, inc_t rs_y, inc_t cs_y \
|
||||
) \
|
||||
{ \
|
||||
ctypex* restrict x_cast = x; \
|
||||
ctypey* restrict b_cast = b; \
|
||||
ctypey* restrict y_cast = y; \
|
||||
\
|
||||
PASTEMAC3(chx,chy,chy,xpbys_mxn) \
|
||||
( \
|
||||
m, n, \
|
||||
x_cast, rs_x, cs_x, \
|
||||
b_cast, \
|
||||
y_cast, rs_y, cs_y \
|
||||
); \
|
||||
}
|
||||
|
||||
static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2);
|
||||
INSERT_GENTFUNC2_BASIC0(xbpys_mxn_fn);
|
||||
INSERT_GENTFUNC2_MIXDP0(xbpys_mxn_fn);
|
||||
|
||||
static xpbys_mxn_vft GENARRAY2_ALL(xbpys_mxn, xbpys_mxn_fn);
|
||||
|
||||
|
||||
void bli_gemm_ker_var2
|
||||
@@ -70,23 +86,8 @@ void bli_gemm_ker_var2
|
||||
thrinfo_t* thread
|
||||
)
|
||||
{
|
||||
#ifdef BLIS_ENABLE_GEMM_MD
|
||||
// By now, A and B have been packed and cast to the execution precision.
|
||||
// In most cases, such as when storage precision of C differs from the
|
||||
// execution precision, we utilize the mixed datatype code path. However,
|
||||
// a few cases still fall within this kernel, such as mixed domain with
|
||||
// equal precision (ccr, crc, rcc), hence those expressions being disabled
|
||||
// in the conditional below.
|
||||
if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) ||
|
||||
//( bli_obj_domain( c ) != bli_obj_domain( b ) ) ||
|
||||
( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) )
|
||||
{
|
||||
bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread );
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
num_t dt_exec = bli_obj_exec_dt( c );
|
||||
num_t dt_c = bli_obj_dt( c );
|
||||
|
||||
pack_t schema_a = bli_obj_pack_schema( a );
|
||||
pack_t schema_b = bli_obj_pack_schema( b );
|
||||
@@ -95,50 +96,55 @@ void bli_gemm_ker_var2
|
||||
dim_t n = bli_obj_width( c );
|
||||
dim_t k = bli_obj_width( a );
|
||||
|
||||
void* buf_a = bli_obj_buffer_at_off( a );
|
||||
inc_t cs_a = bli_obj_col_stride( a );
|
||||
char* a_cast = bli_obj_buffer_at_off( a );
|
||||
inc_t is_a = bli_obj_imag_stride( a );
|
||||
dim_t pd_a = bli_obj_panel_dim( a );
|
||||
inc_t ps_a = bli_obj_panel_stride( a );
|
||||
|
||||
void* buf_b = bli_obj_buffer_at_off( b );
|
||||
inc_t rs_b = bli_obj_row_stride( b );
|
||||
char* b_cast = bli_obj_buffer_at_off( b );
|
||||
inc_t is_b = bli_obj_imag_stride( b );
|
||||
dim_t pd_b = bli_obj_panel_dim( b );
|
||||
inc_t ps_b = bli_obj_panel_stride( b );
|
||||
|
||||
void* buf_c = bli_obj_buffer_at_off( c );
|
||||
char* c_cast = bli_obj_buffer_at_off( c );
|
||||
inc_t rs_c = bli_obj_row_stride( c );
|
||||
inc_t cs_c = bli_obj_col_stride( c );
|
||||
|
||||
obj_t scalar_a;
|
||||
obj_t scalar_b;
|
||||
|
||||
void* buf_alpha;
|
||||
void* buf_beta;
|
||||
|
||||
FUNCPTR_T f;
|
||||
// If any dimension is zero, return immediately.
|
||||
if ( bli_zero_dim3( m, n, k ) ) return;
|
||||
|
||||
// Detach and multiply the scalars attached to A and B.
|
||||
// NOTE: We know that the internal scalars of A and B are already of the
|
||||
// target datatypes because the necessary typecasting would have already
|
||||
// taken place during bli_packm_init().
|
||||
obj_t scalar_a;
|
||||
obj_t scalar_b;
|
||||
bli_obj_scalar_detach( a, &scalar_a );
|
||||
bli_obj_scalar_detach( b, &scalar_b );
|
||||
bli_mulsc( &scalar_a, &scalar_b );
|
||||
|
||||
// Grab the addresses of the internal scalar buffers for the scalar
|
||||
// merged above and the scalar attached to C.
|
||||
buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b );
|
||||
buf_beta = bli_obj_internal_scalar_buffer( c );
|
||||
// NOTE: We know that scalar_b is of type dt_exec due to the above code
|
||||
// that casts the scalars of A and B to dt_exec via scalar_a and scalar_b,
|
||||
// and we know that the internal scalar in C is already of the type dt_c
|
||||
// due to the casting in the implementation of bli_obj_scalar_attach().
|
||||
char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b );
|
||||
char* beta_cast = bli_obj_internal_scalar_buffer( c );
|
||||
|
||||
// If 1m is being employed on a column- or row-stored matrix with a
|
||||
// real-valued beta, we can use the real domain macro-kernel, which
|
||||
// eliminates a little overhead associated with the 1m virtual
|
||||
// micro-kernel.
|
||||
// Only employ this optimization if the storage datatype of C is
|
||||
// equal to the execution/computation datatype.
|
||||
#if 1
|
||||
if ( bli_cntx_method( cntx ) == BLIS_1M )
|
||||
{
|
||||
bli_gemm_ind_recast_1m_params
|
||||
(
|
||||
&dt_exec,
|
||||
&dt_c,
|
||||
schema_a,
|
||||
c,
|
||||
&m, &n, &k,
|
||||
@@ -151,273 +157,211 @@ void bli_gemm_ker_var2
|
||||
|
||||
#ifdef BLIS_ENABLE_GEMM_MD
|
||||
// Tweak parameters in select mixed domain cases (rcc, crc, ccr).
|
||||
bli_gemm_md_ker_var2_recast
|
||||
(
|
||||
&dt_exec,
|
||||
bli_obj_dt( a ),
|
||||
bli_obj_dt( b ),
|
||||
bli_obj_dt( c ),
|
||||
&m, &n, &k,
|
||||
&pd_a, &ps_a,
|
||||
&pd_b, &ps_b,
|
||||
c,
|
||||
&rs_c, &cs_c
|
||||
);
|
||||
if ( bli_cntx_method( cntx ) == BLIS_NAT )
|
||||
{
|
||||
bli_gemm_md_ker_var2_recast
|
||||
(
|
||||
&dt_exec,
|
||||
bli_obj_dt( a ),
|
||||
bli_obj_dt( b ),
|
||||
&dt_c,
|
||||
&m, &n, &k,
|
||||
&pd_a, &ps_a,
|
||||
&pd_b, &ps_b,
|
||||
c,
|
||||
&rs_c, &cs_c
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Index into the type combination array to extract the correct
|
||||
// function pointer.
|
||||
f = ftypes[dt_exec];
|
||||
siz_t dt_size = bli_dt_size( dt_exec );
|
||||
siz_t dt_c_size = bli_dt_size( dt_c );
|
||||
|
||||
// Invoke the function.
|
||||
f( schema_a,
|
||||
schema_b,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
buf_alpha,
|
||||
buf_a, cs_a, is_a,
|
||||
pd_a, ps_a,
|
||||
buf_b, rs_b, is_b,
|
||||
pd_b, ps_b,
|
||||
buf_beta,
|
||||
buf_c, rs_c, cs_c,
|
||||
cntx,
|
||||
rntm,
|
||||
thread );
|
||||
}
|
||||
// Alias some constants to simpler names.
|
||||
const dim_t MR = pd_a;
|
||||
const dim_t NR = pd_b;
|
||||
//const dim_t PACKMR = cs_a;
|
||||
//const dim_t PACKNR = rs_b;
|
||||
|
||||
// Query the context for the micro-kernel address and cast it to its
|
||||
// function pointer type.
|
||||
gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx );
|
||||
|
||||
// Query the params field from the obj_t. If it is non-NULL, grab the ukr
|
||||
// field of the params struct. If that function pointer is non-NULL, use it
|
||||
// as our microkernel instead of the default microkernel queried from the
|
||||
// cntx above.
|
||||
gemm_ker_params_t* params = bli_obj_ker_params( c );
|
||||
gemm_ukr_vft user_ukr = params ? params->ukr : NULL;
|
||||
if ( user_ukr ) gemm_ukr = user_ukr;
|
||||
|
||||
// Temporary C buffer for edge cases. Note that the strides of this
|
||||
// temporary buffer are set so that they match the storage of the
|
||||
// original C matrix. For example, if C is column-stored, ct will be
|
||||
// column-stored as well.
|
||||
char ct[ BLIS_STACK_BUF_MAX_SIZE ]
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx );
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR );
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 );
|
||||
char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO );
|
||||
|
||||
//
|
||||
// Assumptions/assertions:
|
||||
// rs_a == 1
|
||||
// cs_a == PACKMR
|
||||
// pd_a == MR
|
||||
// ps_a == stride to next micro-panel of A
|
||||
// rs_b == PACKNR
|
||||
// cs_b == 1
|
||||
// pd_b == NR
|
||||
// ps_b == stride to next micro-panel of B
|
||||
// rs_c == (no assumptions)
|
||||
// cs_c == (no assumptions)
|
||||
//
|
||||
|
||||
// Compute number of primary and leftover components of the m and n
|
||||
// dimensions.
|
||||
dim_t n_iter = n / NR;
|
||||
dim_t n_left = n % NR;
|
||||
|
||||
dim_t m_iter = m / MR;
|
||||
dim_t m_left = m % MR;
|
||||
|
||||
if ( n_left ) ++n_iter;
|
||||
if ( m_left ) ++m_iter;
|
||||
|
||||
// Determine some increments used to step through A, B, and C.
|
||||
inc_t rstep_a = ps_a * dt_size;
|
||||
|
||||
inc_t cstep_b = ps_b * dt_size;
|
||||
|
||||
inc_t rstep_c = rs_c * MR * dt_c_size;
|
||||
inc_t cstep_c = cs_c * NR * dt_c_size;
|
||||
|
||||
auxinfo_t aux;
|
||||
|
||||
// Save the pack schemas of A and B to the auxinfo_t object.
|
||||
bli_auxinfo_set_schema_a( schema_a, &aux );
|
||||
bli_auxinfo_set_schema_b( schema_b, &aux );
|
||||
|
||||
// Save the imaginary stride of A and B to the auxinfo_t object.
|
||||
bli_auxinfo_set_is_a( is_a, &aux );
|
||||
bli_auxinfo_set_is_b( is_b, &aux );
|
||||
|
||||
// Save the virtual microkernel address and the params.
|
||||
bli_auxinfo_set_ukr( gemm_ukr, &aux );
|
||||
bli_auxinfo_set_params( params, &aux );
|
||||
|
||||
// The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||
// loop around the microkernel. Here we query the thrinfo_t node for the
|
||||
// 1st (ir) loop around the microkernel.
|
||||
thrinfo_t* caucus = bli_thrinfo_sub_node( thread );
|
||||
|
||||
// Query the number of threads and thread ids for each loop.
|
||||
dim_t jr_nt = bli_thread_n_way( thread );
|
||||
dim_t jr_tid = bli_thread_work_id( thread );
|
||||
dim_t ir_nt = bli_thread_n_way( caucus );
|
||||
dim_t ir_tid = bli_thread_work_id( caucus );
|
||||
|
||||
dim_t jr_start, jr_end;
|
||||
dim_t ir_start, ir_end;
|
||||
dim_t jr_inc, ir_inc;
|
||||
|
||||
// Determine the thread range and increment for the 2nd and 1st loops.
|
||||
// NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
// slab or round-robin partitioning was requested at configure-time.
|
||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
|
||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );
|
||||
|
||||
// Loop over the n dimension (NR columns at a time).
|
||||
for ( dim_t j = jr_start; j < jr_end; j += jr_inc )
|
||||
{
|
||||
char* b1 = b_cast + j * cstep_b;
|
||||
char* c1 = c_cast + j * cstep_c;
|
||||
|
||||
dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );
|
||||
|
||||
// Initialize our next panel of B to be the current panel of B.
|
||||
char* b2 = b1;
|
||||
|
||||
// Loop over the m dimension (MR rows at a time).
|
||||
for ( dim_t i = ir_start; i < ir_end; i += ir_inc )
|
||||
{
|
||||
char* a1 = a_cast + i * rstep_a;
|
||||
char* c11 = c1 + i * rstep_c;
|
||||
|
||||
dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );
|
||||
|
||||
// Compute the addresses of the next panels of A and B.
|
||||
char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
|
||||
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
|
||||
{
|
||||
a2 = a_cast;
|
||||
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
|
||||
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
|
||||
b2 = b_cast;
|
||||
}
|
||||
|
||||
// Save addresses of next panels of A and B to the auxinfo_t
|
||||
// object.
|
||||
bli_auxinfo_set_next_a( a2, &aux );
|
||||
bli_auxinfo_set_next_b( b2, &aux );
|
||||
|
||||
// Edge case handling now occurs within the microkernel itself, but
|
||||
// we must still explicitly accumulate to a temporary microtile in
|
||||
// situations where a virtual microkernel is being used, such as
|
||||
// during the 1m method or some cases of mixed datatypes.
|
||||
if ( dt_exec == dt_c )
|
||||
{
|
||||
// Invoke the gemm micro-kernel.
|
||||
gemm_ukr
|
||||
(
|
||||
m_cur,
|
||||
n_cur,
|
||||
k,
|
||||
alpha_cast,
|
||||
a1,
|
||||
b1,
|
||||
beta_cast,
|
||||
c11, rs_c, cs_c,
|
||||
&aux,
|
||||
cntx
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Invoke the gemm micro-kernel.
|
||||
gemm_ukr
|
||||
(
|
||||
MR,
|
||||
NR,
|
||||
k,
|
||||
alpha_cast,
|
||||
a1,
|
||||
b1,
|
||||
zero,
|
||||
&ct, rs_ct, cs_ct,
|
||||
&aux,
|
||||
cntx
|
||||
);
|
||||
|
||||
// Accumulate to C with type-casting.
|
||||
xbpys_mxn[ dt_exec ][ dt_c ]
|
||||
(
|
||||
m_cur, n_cur,
|
||||
&ct, rs_ct, cs_ct,
|
||||
beta_cast,
|
||||
c11, rs_c, cs_c
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC( ctype, ch, varname ) \
|
||||
\
|
||||
void PASTEMAC(ch,varname) \
|
||||
( \
|
||||
pack_t schema_a, \
|
||||
pack_t schema_b, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
void* alpha, \
|
||||
void* a, inc_t cs_a, inc_t is_a, \
|
||||
dim_t pd_a, inc_t ps_a, \
|
||||
void* b, inc_t rs_b, inc_t is_b, \
|
||||
dim_t pd_b, inc_t ps_b, \
|
||||
void* beta, \
|
||||
void* c, inc_t rs_c, inc_t cs_c, \
|
||||
cntx_t* cntx, \
|
||||
rntm_t* rntm, \
|
||||
thrinfo_t* thread \
|
||||
) \
|
||||
{ \
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
/* Alias some constants to simpler names. */ \
|
||||
const dim_t MR = pd_a; \
|
||||
const dim_t NR = pd_b; \
|
||||
/*const dim_t PACKMR = cs_a;*/ \
|
||||
/*const dim_t PACKNR = rs_b;*/ \
|
||||
\
|
||||
/* Query the context for the micro-kernel address and cast it to its
|
||||
function pointer type. Note that the virtual gemm ukernel is queried
|
||||
instead of the native gemm ukernel. This is needed for certain
|
||||
situations for the 1m method that require an extra layer of logic
|
||||
to allow for handling (for example) complex values of beta. Also
|
||||
note that under certain circumstances, the real-domain version of
|
||||
this macrokernel will be called for 1m (NOT the complex version)
|
||||
as an optimization. In these cases, the corresponding real-domain
|
||||
slots within the cntx_t's virtual gemm ukernel func_t will contain
|
||||
pointers to the *native* gemm ukernel, thanks to logic in the
|
||||
context initialization function for the induced method (defined
|
||||
in bli_cntx_ref.c). */ \
|
||||
PASTECH(ch,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
\
|
||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
||||
temporary buffer are set so that they match the storage of the
|
||||
original C matrix. For example, if C is column-stored, ct will be
|
||||
column-stored as well. */ \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
ctype* restrict c_cast = c; \
|
||||
ctype* restrict alpha_cast = alpha; \
|
||||
ctype* restrict beta_cast = beta; \
|
||||
ctype* restrict b1; \
|
||||
ctype* restrict c1; \
|
||||
\
|
||||
dim_t m_iter, m_left; \
|
||||
dim_t n_iter, n_left; \
|
||||
dim_t i, j; \
|
||||
dim_t m_cur; \
|
||||
dim_t n_cur; \
|
||||
inc_t rstep_a; \
|
||||
inc_t cstep_b; \
|
||||
inc_t rstep_c, cstep_c; \
|
||||
auxinfo_t aux; \
|
||||
\
|
||||
/*
|
||||
Assumptions/assertions:
|
||||
rs_a == 1
|
||||
cs_a == PACKMR
|
||||
pd_a == MR
|
||||
ps_a == stride to next micro-panel of A
|
||||
rs_b == PACKNR
|
||||
cs_b == 1
|
||||
pd_b == NR
|
||||
ps_b == stride to next micro-panel of B
|
||||
rs_c == (no assumptions)
|
||||
cs_c == (no assumptions)
|
||||
*/ \
|
||||
\
|
||||
/* If any dimension is zero, return immediately. */ \
|
||||
if ( bli_zero_dim3( m, n, k ) ) return; \
|
||||
\
|
||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the m and n
|
||||
dimensions. */ \
|
||||
n_iter = n / NR; \
|
||||
n_left = n % NR; \
|
||||
\
|
||||
m_iter = m / MR; \
|
||||
m_left = m % MR; \
|
||||
\
|
||||
if ( n_left ) ++n_iter; \
|
||||
if ( m_left ) ++m_iter; \
|
||||
\
|
||||
/* Determine some increments used to step through A, B, and C. */ \
|
||||
rstep_a = ps_a; \
|
||||
\
|
||||
cstep_b = ps_b; \
|
||||
\
|
||||
rstep_c = rs_c * MR; \
|
||||
cstep_c = cs_c * NR; \
|
||||
\
|
||||
/* Save the pack schemas of A and B to the auxinfo_t object. */ \
|
||||
bli_auxinfo_set_schema_a( schema_a, &aux ); \
|
||||
bli_auxinfo_set_schema_b( schema_b, &aux ); \
|
||||
\
|
||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
||||
\
|
||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
||||
1st (ir) loop around the microkernel. */ \
|
||||
thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \
|
||||
\
|
||||
/* Query the number of threads and thread ids for each loop. */ \
|
||||
dim_t jr_nt = bli_thread_n_way( thread ); \
|
||||
dim_t jr_tid = bli_thread_work_id( thread ); \
|
||||
dim_t ir_nt = bli_thread_n_way( caucus ); \
|
||||
dim_t ir_tid = bli_thread_work_id( caucus ); \
|
||||
\
|
||||
dim_t jr_start, jr_end; \
|
||||
dim_t ir_start, ir_end; \
|
||||
dim_t jr_inc, ir_inc; \
|
||||
\
|
||||
/* Determine the thread range and increment for the 2nd and 1st loops.
|
||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
slab or round-robin partitioning was requested at configure-time. */ \
|
||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
||||
\
|
||||
/* Loop over the n dimension (NR columns at a time). */ \
|
||||
for ( j = jr_start; j < jr_end; j += jr_inc ) \
|
||||
{ \
|
||||
ctype* restrict a1; \
|
||||
ctype* restrict c11; \
|
||||
ctype* restrict b2; \
|
||||
\
|
||||
b1 = b_cast + j * cstep_b; \
|
||||
c1 = c_cast + j * cstep_c; \
|
||||
\
|
||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
||||
\
|
||||
/* Initialize our next panel of B to be the current panel of B. */ \
|
||||
b2 = b1; \
|
||||
\
|
||||
/* Loop over the m dimension (MR rows at a time). */ \
|
||||
for ( i = ir_start; i < ir_end; i += ir_inc ) \
|
||||
{ \
|
||||
ctype* restrict a2; \
|
||||
\
|
||||
a1 = a_cast + i * rstep_a; \
|
||||
c11 = c1 + i * rstep_c; \
|
||||
\
|
||||
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
|
||||
\
|
||||
/* Compute the addresses of the next panels of A and B. */ \
|
||||
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
|
||||
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
|
||||
{ \
|
||||
a2 = a_cast; \
|
||||
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
|
||||
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
|
||||
b2 = b_cast; \
|
||||
} \
|
||||
\
|
||||
/* Save addresses of next panels of A and B to the auxinfo_t
|
||||
object. */ \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Scale the bottom edge of C and add the result from above. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
/*
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \
|
||||
*/ \
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" );
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" );
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" );
|
||||
*/
|
||||
}
|
||||
|
||||
INSERT_GENTFUNC_BASIC0( gemm_ker_var2 )
|
||||
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
#ifdef BLIS_ENABLE_GEMM_MD
|
||||
|
||||
#define FUNCPTR_T gemm_fp
|
||||
|
||||
typedef void (*FUNCPTR_T)
|
||||
(
|
||||
pack_t schema_a,
|
||||
pack_t schema_b,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
void* alpha,
|
||||
void* a, inc_t cs_a, inc_t is_a,
|
||||
dim_t pd_a, inc_t ps_a,
|
||||
void* b, inc_t rs_b, inc_t is_b,
|
||||
dim_t pd_b, inc_t ps_b,
|
||||
void* beta,
|
||||
void* c, inc_t rs_c, inc_t cs_c,
|
||||
cntx_t* cntx,
|
||||
rntm_t* rntm,
|
||||
thrinfo_t* thread
|
||||
);
|
||||
|
||||
static FUNCPTR_T GENARRAY2_ALL(ftypes,gemm_ker_var2_md);
|
||||
|
||||
|
||||
void bli_gemm_ker_var2_md
|
||||
(
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* c,
|
||||
cntx_t* cntx,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
)
|
||||
{
|
||||
num_t dt_exec = bli_obj_exec_dt( c );
|
||||
num_t dt_c = bli_obj_dt( c );
|
||||
|
||||
pack_t schema_a = bli_obj_pack_schema( a );
|
||||
pack_t schema_b = bli_obj_pack_schema( b );
|
||||
|
||||
dim_t m = bli_obj_length( c );
|
||||
dim_t n = bli_obj_width( c );
|
||||
dim_t k = bli_obj_width( a );
|
||||
|
||||
void* buf_a = bli_obj_buffer_at_off( a );
|
||||
inc_t cs_a = bli_obj_col_stride( a );
|
||||
inc_t is_a = bli_obj_imag_stride( a );
|
||||
dim_t pd_a = bli_obj_panel_dim( a );
|
||||
inc_t ps_a = bli_obj_panel_stride( a );
|
||||
|
||||
void* buf_b = bli_obj_buffer_at_off( b );
|
||||
inc_t rs_b = bli_obj_row_stride( b );
|
||||
inc_t is_b = bli_obj_imag_stride( b );
|
||||
dim_t pd_b = bli_obj_panel_dim( b );
|
||||
inc_t ps_b = bli_obj_panel_stride( b );
|
||||
|
||||
void* buf_c = bli_obj_buffer_at_off( c );
|
||||
inc_t rs_c = bli_obj_row_stride( c );
|
||||
inc_t cs_c = bli_obj_col_stride( c );
|
||||
|
||||
obj_t scalar_a;
|
||||
obj_t scalar_b;
|
||||
|
||||
void* buf_alpha;
|
||||
void* buf_beta;
|
||||
|
||||
FUNCPTR_T f;
|
||||
|
||||
// Detach and multiply the scalars attached to A and B.
|
||||
// NOTE: We know that the internal scalars of A and B are already of the
|
||||
// target datatypes because the necessary typecasting would have already
|
||||
// taken place during bli_packm_init().
|
||||
bli_obj_scalar_detach( a, &scalar_a );
|
||||
bli_obj_scalar_detach( b, &scalar_b );
|
||||
bli_mulsc( &scalar_a, &scalar_b );
|
||||
|
||||
// Grab the addresses of the internal scalar buffers for the scalar
|
||||
// merged above and the scalar attached to C.
|
||||
// NOTE: We know that scalar_b is of type dt_exec due to the above code
|
||||
// that casts the scalars of A and B to dt_exec via scalar_a and scalar_b,
|
||||
// and we know that the internal scalar in C is already of the type dt_c
|
||||
// due to the casting in the implementation of bli_obj_scalar_attach().
|
||||
buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b );
|
||||
buf_beta = bli_obj_internal_scalar_buffer( c );
|
||||
|
||||
#if 0
|
||||
// NOTE: Turns out that this optimization will never be employed since
|
||||
// currently bli_gemm_ker_var2_md() is only called when the storage
|
||||
// datatype of C differs from the execution/computation datatype, and
|
||||
// this optimization would only make sense if they are equal.
|
||||
|
||||
// If 1m is being employed on a column- or row-stored matrix with a
|
||||
// real-valued beta, we can use the real domain macro-kernel, which
|
||||
// eliminates a little overhead associated with the 1m virtual
|
||||
// micro-kernel.
|
||||
if ( bli_cntx_method( cntx ) == BLIS_1M )
|
||||
{
|
||||
// Only employ this optimization if the storage datatype of C is
|
||||
// equal to the execution/computation datatype.
|
||||
if ( dt_c == dt_exec )
|
||||
{
|
||||
bli_gemm_ind_recast_1m_params
|
||||
(
|
||||
&dt_exec,
|
||||
schema_a,
|
||||
c,
|
||||
&m, &n, &k,
|
||||
&pd_a, &ps_a,
|
||||
&pd_b, &ps_b,
|
||||
&rs_c, &cs_c
|
||||
);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Tweak parameters in select mixed domain cases (rcc, crc, ccr).
|
||||
bli_gemm_md_ker_var2_recast
|
||||
(
|
||||
&dt_exec,
|
||||
bli_obj_dt( a ),
|
||||
bli_obj_dt( b ),
|
||||
bli_obj_dt( c ),
|
||||
&m, &n, &k,
|
||||
&pd_a, &ps_a,
|
||||
&pd_b, &ps_b,
|
||||
c,
|
||||
&rs_c, &cs_c
|
||||
);
|
||||
|
||||
// Index into the type combination array to extract the correct
|
||||
// function pointer.
|
||||
f = ftypes[dt_c][dt_exec];
|
||||
|
||||
// Invoke the function.
|
||||
f( schema_a,
|
||||
schema_b,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
buf_alpha,
|
||||
buf_a, cs_a, is_a,
|
||||
pd_a, ps_a,
|
||||
buf_b, rs_b, is_b,
|
||||
pd_b, ps_b,
|
||||
buf_beta,
|
||||
buf_c, rs_c, cs_c,
|
||||
cntx,
|
||||
rntm,
|
||||
thread );
|
||||
}
|
||||
|
||||
|
||||
#undef GENTFUNC2
|
||||
#define GENTFUNC2( ctype_c, ctype_e, chc, che, varname ) \
|
||||
\
|
||||
void PASTEMAC2(chc,che,varname) \
|
||||
( \
|
||||
pack_t schema_a, \
|
||||
pack_t schema_b, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
void* alpha, \
|
||||
void* a, inc_t cs_a, inc_t is_a, \
|
||||
dim_t pd_a, inc_t ps_a, \
|
||||
void* b, inc_t rs_b, inc_t is_b, \
|
||||
dim_t pd_b, inc_t ps_b, \
|
||||
void* beta, \
|
||||
void* c, inc_t rs_c, inc_t cs_c, \
|
||||
cntx_t* cntx, \
|
||||
rntm_t* rntm, \
|
||||
thrinfo_t* thread \
|
||||
) \
|
||||
{ \
|
||||
const num_t dte = PASTEMAC(che,type); \
|
||||
/*const num_t dtc = PASTEMAC(chc,type);*/ \
|
||||
\
|
||||
/* Alias some constants to simpler names. */ \
|
||||
const dim_t MR = pd_a; \
|
||||
const dim_t NR = pd_b; \
|
||||
/*const dim_t PACKMR = cs_a;*/ \
|
||||
/*const dim_t PACKNR = rs_b;*/ \
|
||||
\
|
||||
/* Query the context for the micro-kernel address and cast it to its
|
||||
function pointer type. */ \
|
||||
PASTECH(che,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dte, BLIS_GEMM_UKR, cntx ); \
|
||||
\
|
||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
||||
temporary buffer are set so that they match the storage of the
|
||||
original C matrix. For example, if C is column-stored, ct will be
|
||||
column-stored as well. */ \
|
||||
ctype_e ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype_e ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dte, BLIS_GEMM_UKR, cntx ); \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype_e* restrict zero = PASTEMAC(che,0); \
|
||||
ctype_e* restrict a_cast = a; \
|
||||
ctype_e* restrict b_cast = b; \
|
||||
ctype_c* restrict c_cast = c; \
|
||||
ctype_e* restrict alpha_cast = alpha; \
|
||||
ctype_c* restrict beta_cast = beta; \
|
||||
ctype_e* restrict b1; \
|
||||
ctype_c* restrict c1; \
|
||||
\
|
||||
dim_t m_iter, m_left; \
|
||||
dim_t n_iter, n_left; \
|
||||
dim_t i, j; \
|
||||
dim_t m_cur; \
|
||||
dim_t n_cur; \
|
||||
inc_t rstep_a; \
|
||||
inc_t cstep_b; \
|
||||
inc_t rstep_c, cstep_c; \
|
||||
auxinfo_t aux; \
|
||||
\
|
||||
/*
|
||||
Assumptions/assertions:
|
||||
rs_a == 1
|
||||
cs_a == PACKMR
|
||||
pd_a == MR
|
||||
ps_a == stride to next micro-panel of A
|
||||
rs_b == PACKNR
|
||||
cs_b == 1
|
||||
pd_b == NR
|
||||
ps_b == stride to next micro-panel of B
|
||||
rs_c == (no assumptions)
|
||||
cs_c == (no assumptions)
|
||||
*/ \
|
||||
\
|
||||
/* If any dimension is zero, return immediately. */ \
|
||||
if ( bli_zero_dim3( m, n, k ) ) return; \
|
||||
\
|
||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
||||
PASTEMAC(che,set0s_mxn)( MR, NR, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the m and n
|
||||
dimensions. */ \
|
||||
n_iter = n / NR; \
|
||||
n_left = n % NR; \
|
||||
\
|
||||
m_iter = m / MR; \
|
||||
m_left = m % MR; \
|
||||
\
|
||||
if ( n_left ) ++n_iter; \
|
||||
if ( m_left ) ++m_iter; \
|
||||
\
|
||||
/* Determine some increments used to step through A, B, and C. */ \
|
||||
rstep_a = ps_a; \
|
||||
\
|
||||
cstep_b = ps_b; \
|
||||
\
|
||||
rstep_c = rs_c * MR; \
|
||||
cstep_c = cs_c * NR; \
|
||||
\
|
||||
/* Save the pack schemas of A and B to the auxinfo_t object. */ \
|
||||
bli_auxinfo_set_schema_a( schema_a, &aux ); \
|
||||
bli_auxinfo_set_schema_b( schema_b, &aux ); \
|
||||
\
|
||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
||||
\
|
||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
||||
1st (ir) loop around the microkernel. */ \
|
||||
thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \
|
||||
\
|
||||
/* Query the number of threads and thread ids for each loop. */ \
|
||||
dim_t jr_nt = bli_thread_n_way( thread ); \
|
||||
dim_t jr_tid = bli_thread_work_id( thread ); \
|
||||
dim_t ir_nt = bli_thread_n_way( caucus ); \
|
||||
dim_t ir_tid = bli_thread_work_id( caucus ); \
|
||||
\
|
||||
dim_t jr_start, jr_end; \
|
||||
dim_t ir_start, ir_end; \
|
||||
dim_t jr_inc, ir_inc; \
|
||||
\
|
||||
/* Determine the thread range and increment for the 2nd and 1st loops.
|
||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
slab or round-robin partitioning was requested at configure-time. */ \
|
||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
||||
\
|
||||
/* Loop over the n dimension (NR columns at a time). */ \
|
||||
for ( j = jr_start; j < jr_end; j += jr_inc ) \
|
||||
{ \
|
||||
ctype_e* restrict a1; \
|
||||
ctype_c* restrict c11; \
|
||||
ctype_e* restrict b2; \
|
||||
\
|
||||
b1 = b_cast + j * cstep_b; \
|
||||
c1 = c_cast + j * cstep_c; \
|
||||
\
|
||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
||||
\
|
||||
/* Initialize our next panel of B to be the current panel of B. */ \
|
||||
b2 = b1; \
|
||||
\
|
||||
/* Loop over the m dimension (MR rows at a time). */ \
|
||||
for ( i = ir_start; i < ir_end; i += ir_inc ) \
|
||||
{ \
|
||||
ctype_e* restrict a2; \
|
||||
\
|
||||
a1 = a_cast + i * rstep_a; \
|
||||
c11 = c1 + i * rstep_c; \
|
||||
\
|
||||
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
|
||||
\
|
||||
/* Compute the addresses of the next panels of A and B. */ \
|
||||
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
|
||||
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
|
||||
{ \
|
||||
a2 = a_cast; \
|
||||
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
|
||||
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
|
||||
b2 = b_cast; \
|
||||
} \
|
||||
\
|
||||
/* Save addresses of next panels of A and B to the auxinfo_t
|
||||
object. */ \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Always save the micropanel product to the local microtile and
|
||||
then accumulate it into C via the xpbys_mxn macro. */ \
|
||||
/*if ( 1 )*/ \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Scale the microtile of C and add the result from above. */ \
|
||||
PASTEMAC3(che,chc,chc,xpbys_mxn) \
|
||||
( \
|
||||
m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
/*
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \
|
||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \
|
||||
*/ \
|
||||
}
|
||||
|
||||
INSERT_GENTFUNC2_BASIC0( gemm_ker_var2_md )
|
||||
INSERT_GENTFUNC2_MIXDP0( gemm_ker_var2_md )
|
||||
|
||||
#endif
|
||||
@@ -154,7 +154,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
num_t* dt_comp,
|
||||
num_t dt_a,
|
||||
num_t dt_b,
|
||||
num_t dt_c,
|
||||
num_t* dt_c,
|
||||
dim_t* m,
|
||||
dim_t* n,
|
||||
dim_t* k,
|
||||
@@ -164,7 +164,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
inc_t* rs_c, inc_t* cs_c
|
||||
)
|
||||
{
|
||||
if ( bli_is_real( dt_c ) &&
|
||||
if ( bli_is_real( *dt_c ) &&
|
||||
bli_is_complex( dt_a ) &&
|
||||
bli_is_complex( dt_b ) )
|
||||
{
|
||||
@@ -177,7 +177,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
*ps_a *= 2;
|
||||
*ps_b *= 2;
|
||||
}
|
||||
else if ( bli_is_complex( dt_c ) &&
|
||||
else if ( bli_is_complex( *dt_c ) &&
|
||||
bli_is_real( dt_a ) &&
|
||||
bli_is_complex( dt_b ) )
|
||||
{
|
||||
@@ -197,6 +197,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
// to the real virtual microkernel slots of the context) instead of
|
||||
// the complex macrokernel and c2r virtual microkernel.
|
||||
*dt_comp = bli_dt_proj_to_real( *dt_comp );
|
||||
*dt_c = bli_dt_proj_to_real( *dt_c );
|
||||
*n *= 2;
|
||||
*pd_b *= 2; *ps_b *= 2;
|
||||
*rs_c *= 2;
|
||||
@@ -211,7 +212,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
*ps_a /= 2;
|
||||
}
|
||||
}
|
||||
else if ( bli_is_complex( dt_c ) &&
|
||||
else if ( bli_is_complex( *dt_c ) &&
|
||||
bli_is_complex( dt_a ) &&
|
||||
bli_is_real( dt_b ) )
|
||||
{
|
||||
@@ -231,6 +232,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
// to the real virtual microkernel slots of the context) instead of
|
||||
// the complex macrokernel and c2r virtual microkernel.
|
||||
*dt_comp = bli_dt_proj_to_real( *dt_comp );
|
||||
*dt_c = bli_dt_proj_to_real( *dt_c );
|
||||
*m *= 2;
|
||||
*pd_a *= 2; *ps_a *= 2;
|
||||
*cs_c *= 2;
|
||||
@@ -274,54 +276,3 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
||||
#endif
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
//
|
||||
// Prototype object-based interfaces.
|
||||
//
|
||||
|
||||
#undef GENPROT
|
||||
#define GENPROT( opname ) \
|
||||
\
|
||||
void PASTEMAC0(opname) \
|
||||
( \
|
||||
obj_t* a, \
|
||||
obj_t* b, \
|
||||
obj_t* c, \
|
||||
cntx_t* cntx, \
|
||||
rntm_t* rntm, \
|
||||
cntl_t* cntl, \
|
||||
thrinfo_t* thread \
|
||||
);
|
||||
|
||||
GENPROT( gemm_ker_var2_md )
|
||||
|
||||
//
|
||||
// Prototype BLAS-like interfaces with void pointer operands.
|
||||
//
|
||||
|
||||
#undef GENTPROT2
|
||||
#define GENTPROT2( ctype_c, ctype_e, chc, che, varname ) \
|
||||
\
|
||||
void PASTEMAC2(chc,che,varname) \
|
||||
( \
|
||||
pack_t schema_a, \
|
||||
pack_t schema_b, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
void* alpha, \
|
||||
void* a, inc_t cs_a, inc_t is_a, \
|
||||
dim_t pd_a, inc_t ps_a, \
|
||||
void* b, inc_t rs_b, inc_t is_b, \
|
||||
dim_t pd_b, inc_t ps_b, \
|
||||
void* beta, \
|
||||
void* c, inc_t rs_c, inc_t cs_c, \
|
||||
cntx_t* cntx, \
|
||||
rntm_t* rntm, \
|
||||
thrinfo_t* thread \
|
||||
);
|
||||
|
||||
INSERT_GENTPROT2_BASIC0( gemm_ker_var2_md )
|
||||
INSERT_GENTPROT2_MIXDP0( gemm_ker_var2_md )
|
||||
|
||||
|
||||
@@ -41,6 +41,8 @@
|
||||
\
|
||||
void PASTEMAC2(ch,opname,suf) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
@@ -61,6 +63,9 @@ void PASTEMAC2(ch,opname,suf) \
|
||||
\
|
||||
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
dim_t mr_r = mr; \
|
||||
dim_t nr_r = nr; \
|
||||
\
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype_r ) ] \
|
||||
@@ -81,6 +86,9 @@ void PASTEMAC2(ch,opname,suf) \
|
||||
\
|
||||
ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \
|
||||
ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \
|
||||
\
|
||||
dim_t m_use; \
|
||||
dim_t n_use; \
|
||||
\
|
||||
ctype_r* c_use; \
|
||||
inc_t rs_c_use; \
|
||||
@@ -146,17 +154,16 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
||||
rs_c_use = rs_ct; \
|
||||
cs_c_use = cs_ct; \
|
||||
\
|
||||
/* Convert the strides from being in units of complex elements to
|
||||
be in units of real elements. Note that we don't need to check for
|
||||
general storage here because that case corresponds to the scenario
|
||||
where we are using the ct buffer and its rs_ct/cs_ct strides. */ \
|
||||
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \
|
||||
else rs_c_use *= 2; \
|
||||
\
|
||||
/* Convert the strides and corresponding microtile dimension from being
|
||||
in units of complex elements to be in units of real elements. */ \
|
||||
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \
|
||||
else { rs_c_use *= 2; nr_r *= 2; }\
|
||||
\
|
||||
/* c = beta * c + alpha_r * a * b; */ \
|
||||
rgemm_ukr \
|
||||
( \
|
||||
mr_r, \
|
||||
nr_r, \
|
||||
k, \
|
||||
alpha_r, \
|
||||
a_r, \
|
||||
@@ -166,14 +173,12 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
||||
data, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
dim_t i, j; \
|
||||
\
|
||||
/* Accumulate the final result in ct back to c. */ \
|
||||
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
||||
{ \
|
||||
for ( j = 0; j < nr; ++j ) \
|
||||
for ( i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
{ \
|
||||
PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
|
||||
*(c + i*rs_c + j*cs_c ) ); \
|
||||
@@ -181,8 +186,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
||||
} \
|
||||
else if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||
{ \
|
||||
for ( j = 0; j < nr; ++j ) \
|
||||
for ( i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
{ \
|
||||
PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||
*(c + i*rs_c + j*cs_c ) ); \
|
||||
@@ -190,8 +195,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( j = 0; j < nr; ++j ) \
|
||||
for ( i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
{ \
|
||||
PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||
*beta, \
|
||||
@@ -207,17 +212,19 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
||||
c_use = ( ctype_r* )c; \
|
||||
rs_c_use = rs_c; \
|
||||
cs_c_use = cs_c; \
|
||||
m_use = m; \
|
||||
n_use = n; \
|
||||
\
|
||||
/* Convert the strides from being in units of complex elements to
|
||||
be in units of real elements. Note that we don't need to check for
|
||||
general storage here because that case corresponds to the scenario
|
||||
where we are using the ct buffer and its rs_ct/cs_ct strides. */ \
|
||||
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \
|
||||
else rs_c_use *= 2; \
|
||||
/* Convert the strides and corresponding microtile dimension from being
|
||||
in units of complex elements to be in units of real elements. */ \
|
||||
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \
|
||||
else { rs_c_use *= 2; n_use *= 2; } \
|
||||
\
|
||||
/* c = beta * c + alpha_r * a * b; */ \
|
||||
rgemm_ukr \
|
||||
( \
|
||||
m_use, \
|
||||
n_use, \
|
||||
k, \
|
||||
alpha_r, \
|
||||
a_r, \
|
||||
|
||||
@@ -34,6 +34,16 @@
|
||||
*/
|
||||
|
||||
|
||||
//
|
||||
// gemm kernel parameter struct.
|
||||
//
|
||||
|
||||
typedef struct
|
||||
{
|
||||
gemm_ukr_vft ukr;
|
||||
} gemm_ker_params_t;
|
||||
|
||||
|
||||
//
|
||||
// Prototype object-based interfaces.
|
||||
//
|
||||
@@ -59,32 +69,3 @@ GENPROT( gemm_blk_var3 )
|
||||
GENPROT( gemm_ker_var1 )
|
||||
GENPROT( gemm_ker_var2 )
|
||||
|
||||
|
||||
//
|
||||
// Prototype BLAS-like interfaces with void pointer operands.
|
||||
//
|
||||
|
||||
#undef GENTPROT
|
||||
#define GENTPROT( ctype, ch, varname ) \
|
||||
\
|
||||
void PASTEMAC(ch,varname) \
|
||||
( \
|
||||
pack_t schema_a, \
|
||||
pack_t schema_b, \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
void* alpha, \
|
||||
void* a, inc_t cs_a, inc_t is_a, \
|
||||
dim_t pd_a, inc_t ps_a, \
|
||||
void* b, inc_t rs_b, inc_t is_b, \
|
||||
dim_t pd_b, inc_t ps_b, \
|
||||
void* beta, \
|
||||
void* c, inc_t rs_c, inc_t cs_c, \
|
||||
cntx_t* cntx, \
|
||||
rntm_t* rntm, \
|
||||
thrinfo_t* thread \
|
||||
);
|
||||
|
||||
INSERT_GENTPROT_BASIC0( gemm_ker_var2 )
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
BLIS_INLINE void bli_gemm_ind_recast_1m_params
|
||||
(
|
||||
num_t* dt_exec,
|
||||
num_t* dt_c,
|
||||
pack_t schema_a,
|
||||
obj_t* c,
|
||||
dim_t* m,
|
||||
@@ -57,6 +58,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params
|
||||
!bli_is_gen_stored( *rs_c, *cs_c ) )
|
||||
{
|
||||
*dt_exec = bli_dt_proj_to_real( *dt_exec );
|
||||
*dt_c = bli_dt_proj_to_real( *dt_c );
|
||||
|
||||
if ( bli_is_1e_packed( schema_a ) )
|
||||
{
|
||||
|
||||
@@ -279,6 +279,9 @@ void PASTEMAC(ch,varname) \
|
||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
||||
\
|
||||
/* Save the desired output datatype (indicating no typecasting). */ \
|
||||
/*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \
|
||||
\
|
||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
||||
@@ -381,43 +384,20 @@ void PASTEMAC(ch,varname) \
|
||||
And if we're strictly above the diagonal, we do nothing and
|
||||
continue. */ \
|
||||
{ \
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Scale the edge of C and add the result. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
@@ -490,6 +470,8 @@ void PASTEMAC(ch,varname) \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
MR, \
|
||||
NR, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
@@ -509,43 +491,20 @@ void PASTEMAC(ch,varname) \
|
||||
} \
|
||||
else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
|
||||
{ \
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Scale the edge of C and add the result. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
@@ -281,6 +281,9 @@ void PASTEMAC(ch,varname) \
|
||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
||||
\
|
||||
/* Save the desired output datatype (indicating no typecasting). */ \
|
||||
/*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \
|
||||
\
|
||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
||||
@@ -385,6 +388,8 @@ void PASTEMAC(ch,varname) \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
MR, \
|
||||
NR, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
@@ -404,43 +409,20 @@ void PASTEMAC(ch,varname) \
|
||||
} \
|
||||
else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
|
||||
{ \
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Scale the edge of C and add the result. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
@@ -512,43 +494,20 @@ void PASTEMAC(ch,varname) \
|
||||
And if we're strictly below the diagonal, we do nothing and
|
||||
continue. */ \
|
||||
{ \
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Scale the edge of C and add the result. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
||||
function pointer type. */ \
|
||||
PASTECH(ch,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
\
|
||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
||||
temporary buffer are set so that they match the storage of the
|
||||
original C matrix. For example, if C is column-stored, ct will be
|
||||
column-stored as well. */ \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict one = PASTEMAC(ch,1); \
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
ctype* restrict c_cast = c; \
|
||||
@@ -254,10 +242,6 @@ void PASTEMAC(ch,varname) \
|
||||
diagoffa = 0; \
|
||||
c_cast = c_cast + (i )*rs_c; \
|
||||
} \
|
||||
\
|
||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the m and n
|
||||
dimensions. */ \
|
||||
@@ -307,8 +291,8 @@ void PASTEMAC(ch,varname) \
|
||||
dim_t jr_inc; \
|
||||
\
|
||||
/* Determine the thread range and increment for the 2nd loop.
|
||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
slab or round-robin partitioning was requested at configure-time. \
|
||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
slab or round-robin partitioning was requested at configure-time. \
|
||||
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
||||
/*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \
|
||||
@@ -379,47 +363,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_a1011, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1_i, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Copy edge elements of C to the temporary buffer. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
c11, rs_c, cs_c, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_a1011, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1_i, \
|
||||
beta_cast, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Copy the result to the edge of C. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k_a1011, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1_i, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
/*}*/ \
|
||||
\
|
||||
a1 += ps_a_cur; \
|
||||
@@ -446,42 +403,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
/*}*/ \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
|
||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
||||
function pointer type. */ \
|
||||
PASTECH(ch,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
\
|
||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
||||
temporary buffer are set so that they match the storage of the
|
||||
original C matrix. For example, if C is column-stored, ct will be
|
||||
column-stored as well. */ \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict one = PASTEMAC(ch,1); \
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
ctype* restrict c_cast = c; \
|
||||
@@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \
|
||||
{ \
|
||||
m = -diagoffa + k; \
|
||||
} \
|
||||
\
|
||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the m and n
|
||||
dimensions. */ \
|
||||
@@ -386,47 +370,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_a1112, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1_i, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Copy edge elements of C to the temporary buffer. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
c11, rs_c, cs_c, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_a1112, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1_i, \
|
||||
beta_cast, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Copy the result to the edge of C. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k_a1112, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1_i, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
/*}*/ \
|
||||
\
|
||||
a1 += ps_a_cur; \
|
||||
@@ -453,42 +410,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
/*}*/ \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
|
||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
||||
function pointer type. */ \
|
||||
PASTECH(ch,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
\
|
||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
||||
temporary buffer are set so that they match the storage of the
|
||||
original C matrix. For example, if C is column-stored, ct will be
|
||||
column-stored as well. */ \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict one = PASTEMAC(ch,1); \
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
ctype* restrict c_cast = c; \
|
||||
@@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \
|
||||
{ \
|
||||
n = diagoffb + k; \
|
||||
} \
|
||||
\
|
||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the m and n
|
||||
dimensions. */ \
|
||||
@@ -335,9 +319,9 @@ void PASTEMAC(ch,varname) \
|
||||
\
|
||||
/* Determine the thread range and increment for the 2nd and 1st loops for
|
||||
the initial rectangular region of B (if it exists).
|
||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
slab or round-robin partitioning was requested at configure-time. \
|
||||
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||
slab or round-robin partitioning was requested at configure-time. \
|
||||
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
||||
bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
||||
\
|
||||
@@ -382,42 +366,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
@@ -501,47 +463,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_b1121, \
|
||||
alpha_cast, \
|
||||
a1_i, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Copy edge elements of C to the temporary buffer. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
c11, rs_c, cs_c, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_b1121, \
|
||||
alpha_cast, \
|
||||
a1_i, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Copy the result to the edge of C. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k_b1121, \
|
||||
alpha_cast, \
|
||||
a1_i, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
|
||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
||||
function pointer type. */ \
|
||||
PASTECH(ch,gemm_ukr_ft) \
|
||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
\
|
||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
||||
temporary buffer are set so that they match the storage of the
|
||||
original C matrix. For example, if C is column-stored, ct will be
|
||||
column-stored as well. */ \
|
||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||
/ sizeof( ctype ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict one = PASTEMAC(ch,1); \
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
ctype* restrict c_cast = c; \
|
||||
@@ -262,10 +250,6 @@ void PASTEMAC(ch,varname) \
|
||||
{ \
|
||||
k = -diagoffb + n; \
|
||||
} \
|
||||
\
|
||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Compute number of primary and leftover components of the m and n
|
||||
dimensions. */ \
|
||||
@@ -410,47 +394,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_b0111, \
|
||||
alpha_cast, \
|
||||
a1_i, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Copy edge elements of C to the temporary buffer. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
c11, rs_c, cs_c, \
|
||||
ct, rs_ct, cs_ct ); \
|
||||
\
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k_b0111, \
|
||||
alpha_cast, \
|
||||
a1_i, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Copy the result to the edge of C. */ \
|
||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k_b0111, \
|
||||
alpha_cast, \
|
||||
a1_i, \
|
||||
b1, \
|
||||
beta_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
@@ -476,9 +433,9 @@ void PASTEMAC(ch,varname) \
|
||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
||||
\
|
||||
/* Advance the start and end iteration offsets for the rectangular region
|
||||
by the number of iterations used for the triangular region. */ \
|
||||
jr_start += n_iter_tri; \
|
||||
jr_end += n_iter_tri; \
|
||||
by the number of iterations used for the triangular region. */ \
|
||||
jr_start += n_iter_tri; \
|
||||
jr_end += n_iter_tri; \
|
||||
jb0 = n_iter_tri; \
|
||||
\
|
||||
/* Save the resulting value of b1 from the previous loop since it represents
|
||||
@@ -496,7 +453,7 @@ void PASTEMAC(ch,varname) \
|
||||
the starting address of the rectangular region (which is already
|
||||
n_iter_tri logical iterations through B). */ \
|
||||
b1 = b_cast + (j-jb0) * cstep_b; \
|
||||
c1 = c_cast + j * cstep_c; \
|
||||
c1 = c_cast + j * cstep_c; \
|
||||
\
|
||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
||||
\
|
||||
@@ -533,42 +490,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
alpha_cast, \
|
||||
a1, \
|
||||
b1, \
|
||||
one, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
||||
@@ -40,27 +40,30 @@ cntl_t* bli_trsm_cntl_create
|
||||
rntm_t* rntm,
|
||||
side_t side,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
)
|
||||
{
|
||||
if ( bli_is_left( side ) )
|
||||
return bli_trsm_l_cntl_create( rntm, schema_a, schema_b );
|
||||
return bli_trsm_l_cntl_create( rntm, schema_a, schema_b, ker );
|
||||
else
|
||||
return bli_trsm_r_cntl_create( rntm, schema_a, schema_b );
|
||||
return bli_trsm_r_cntl_create( rntm, schema_a, schema_b, ker );
|
||||
}
|
||||
|
||||
cntl_t* bli_trsm_l_cntl_create
|
||||
(
|
||||
rntm_t* rntm,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
)
|
||||
{
|
||||
void_fp macro_kernel_p;
|
||||
|
||||
// Use the function pointer to the macrokernels that use slab
|
||||
// assignment of micropanels to threads in the jr and ir loops.
|
||||
// Set the default macrokernel. If a non-NULL kernel function pointer is
|
||||
// passed in, we use that instead.
|
||||
macro_kernel_p = bli_trsm_xx_ker_var2;
|
||||
if ( ker ) macro_kernel_p = ker;
|
||||
|
||||
const opid_t family = BLIS_TRSM;
|
||||
|
||||
@@ -202,11 +205,15 @@ cntl_t* bli_trsm_r_cntl_create
|
||||
(
|
||||
rntm_t* rntm,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
)
|
||||
{
|
||||
// NOTE: trsm macrokernels are presently disabled for right-side execution.
|
||||
// Set the default macrokernel. If a non-NULL kernel function pointer is
|
||||
// passed in, we use that instead.
|
||||
void_fp macro_kernel_p = bli_trsm_xx_ker_var2;
|
||||
if ( ker ) macro_kernel_p = ker;
|
||||
|
||||
const opid_t family = BLIS_TRSM;
|
||||
|
||||
|
||||
@@ -38,21 +38,24 @@ cntl_t* bli_trsm_cntl_create
|
||||
rntm_t* rntm,
|
||||
side_t side,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
);
|
||||
|
||||
cntl_t* bli_trsm_l_cntl_create
|
||||
(
|
||||
rntm_t* rntm,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
);
|
||||
|
||||
cntl_t* bli_trsm_r_cntl_create
|
||||
(
|
||||
rntm_t* rntm,
|
||||
pack_t schema_a,
|
||||
pack_t schema_b
|
||||
pack_t schema_b,
|
||||
void_fp ker
|
||||
);
|
||||
|
||||
void bli_trsm_cntl_free
|
||||
|
||||
@@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
@@ -470,43 +469,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1, \
|
||||
b1, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1, \
|
||||
b1, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
} \
|
||||
|
||||
@@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
@@ -480,43 +479,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1, \
|
||||
b1, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1, \
|
||||
b1, \
|
||||
zero, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1, \
|
||||
b1, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
} \
|
||||
|
||||
@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
@@ -499,43 +498,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( b2, &aux ); \
|
||||
bli_auxinfo_set_next_b( a2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
b1, \
|
||||
a1, \
|
||||
alpha2_cast, \
|
||||
c11, cs_c, rs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
b1, \
|
||||
a1, \
|
||||
zero, \
|
||||
ct, cs_ct, rs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
minus_one, \
|
||||
b1, \
|
||||
a1, \
|
||||
alpha2_cast, \
|
||||
c11, cs_c, rs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
|
||||
@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
|
||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||
\
|
||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict b_cast = b; \
|
||||
@@ -492,43 +491,20 @@ void PASTEMAC(ch,varname) \
|
||||
bli_auxinfo_set_next_a( b2, &aux ); \
|
||||
bli_auxinfo_set_next_b( a2, &aux ); \
|
||||
\
|
||||
/* Handle interior and edge cases separately. */ \
|
||||
if ( m_cur == MR && n_cur == NR ) \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
b1, \
|
||||
a1, \
|
||||
alpha2_cast, \
|
||||
c11, cs_c, rs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
k, \
|
||||
minus_one, \
|
||||
b1, \
|
||||
a1, \
|
||||
zero, \
|
||||
ct, cs_ct, rs_ct, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
\
|
||||
/* Add the result to the edge of C. */ \
|
||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
||||
ct, rs_ct, cs_ct, \
|
||||
alpha2_cast, \
|
||||
c11, rs_c, cs_c ); \
|
||||
} \
|
||||
/* Invoke the gemm micro-kernel. */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
m_cur, \
|
||||
n_cur, \
|
||||
k, \
|
||||
minus_one, \
|
||||
b1, \
|
||||
a1, \
|
||||
alpha2_cast, \
|
||||
c11, cs_c, rs_c, \
|
||||
&aux, \
|
||||
cntx \
|
||||
); \
|
||||
} \
|
||||
\
|
||||
a1 += rstep_a; \
|
||||
|
||||
@@ -74,6 +74,15 @@ BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai )
|
||||
return ai->ps_b;
|
||||
}
|
||||
|
||||
BLIS_INLINE void_fp bli_auxinfo_ukr( auxinfo_t* ai )
|
||||
{
|
||||
return ai->ukr;
|
||||
}
|
||||
BLIS_INLINE void* bli_auxinfo_params( auxinfo_t* ai )
|
||||
{
|
||||
return ai->params;
|
||||
}
|
||||
|
||||
|
||||
// auxinfo_t field modification
|
||||
|
||||
@@ -118,5 +127,14 @@ BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai )
|
||||
ai->ps_b = ps;
|
||||
}
|
||||
|
||||
#endif
|
||||
BLIS_INLINE void bli_auxinfo_set_ukr( void_fp ukr, auxinfo_t* ai )
|
||||
{
|
||||
ai->ukr = ukr;
|
||||
}
|
||||
BLIS_INLINE void bli_auxinfo_set_params( void* params, auxinfo_t* ai )
|
||||
{
|
||||
ai->params = params;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
109
frame/include/bli_edge_case_macro_defs.h
Normal file
109
frame/include/bli_edge_case_macro_defs.h
Normal file
@@ -0,0 +1,109 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2021, The University of Texas at Austin
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#ifndef BLIS_EDGE_CASE_MACRO_DEFS_H
|
||||
#define BLIS_EDGE_CASE_MACRO_DEFS_H
|
||||
|
||||
|
||||
// Helper macros for edge-case handling within gemm microkernels.
|
||||
|
||||
#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \
|
||||
\
|
||||
PASTEMAC(ch,ctype)* restrict _beta = beta; \
|
||||
PASTEMAC(ch,ctype)* restrict _c = c; \
|
||||
const inc_t _rs_c = rs_c; \
|
||||
const inc_t _cs_c = cs_c; \
|
||||
PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \
|
||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||
const inc_t _rs_ct = row_major ? nr : 1; \
|
||||
const inc_t _cs_ct = row_major ? 1 : mr;
|
||||
|
||||
#define GEMM_UKR_SETUP_CT_POST(ch) \
|
||||
\
|
||||
PASTEMAC(ch,ctype) _zero; \
|
||||
PASTEMAC(ch,set0s)( _zero ); \
|
||||
\
|
||||
if ( _use_ct ) \
|
||||
{ \
|
||||
c = _ct; \
|
||||
rs_c = _rs_ct; \
|
||||
cs_c = _cs_ct; \
|
||||
beta = &_zero; \
|
||||
}
|
||||
|
||||
#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \
|
||||
\
|
||||
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||
const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
|
||||
m != mr || n != nr; \
|
||||
GEMM_UKR_SETUP_CT_POST(ch);
|
||||
|
||||
#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \
|
||||
\
|
||||
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||
const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \
|
||||
m != mr || n != nr; \
|
||||
GEMM_UKR_SETUP_CT_POST(ch);
|
||||
|
||||
#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \
|
||||
\
|
||||
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||
const bool _use_ct = m != mr || n != nr; \
|
||||
GEMM_UKR_SETUP_CT_POST(ch);
|
||||
|
||||
#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \
|
||||
\
|
||||
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||
const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
|
||||
m != mr || n != nr || \
|
||||
( (uintptr_t)_c % alignment ) || \
|
||||
( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \
|
||||
GEMM_UKR_SETUP_CT_POST(ch);
|
||||
|
||||
#define GEMM_UKR_FLUSH_CT(ch) \
|
||||
\
|
||||
if ( _use_ct ) \
|
||||
{ \
|
||||
PASTEMAC(ch,xpbys_mxn) \
|
||||
( \
|
||||
m, n, \
|
||||
_ct, _rs_ct, _cs_ct, \
|
||||
_beta, \
|
||||
_c, _rs_c, _cs_c \
|
||||
); \
|
||||
} \
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -98,6 +98,7 @@
|
||||
#include "bli_gentprot_macro_defs.h"
|
||||
|
||||
#include "bli_misc_macro_defs.h"
|
||||
#include "bli_edge_case_macro_defs.h"
|
||||
#include "bli_param_macro_defs.h"
|
||||
#include "bli_obj_macro_defs.h"
|
||||
#include "bli_complex_macro_defs.h"
|
||||
|
||||
@@ -1144,6 +1144,13 @@ typedef struct
|
||||
inc_t ps_a;
|
||||
inc_t ps_b;
|
||||
|
||||
// The type to convert to on output.
|
||||
//num_t dt_on_output;
|
||||
|
||||
// (Virtual) microkernel address and additional parameters.
|
||||
void_fp ukr;
|
||||
void* params;
|
||||
|
||||
} auxinfo_t;
|
||||
|
||||
|
||||
|
||||
@@ -42,9 +42,13 @@
|
||||
// 2vx10 microkernels.
|
||||
#include "armsve_asm_2vx10cmplx.h"
|
||||
|
||||
#include "arm_sve.h"
|
||||
|
||||
void bli_cgemm_armsve_asm_2vx10_unindexed
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* restrict alpha,
|
||||
scomplex* restrict a,
|
||||
scomplex* restrict b,
|
||||
@@ -59,12 +63,15 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_mker = k0 / 4;
|
||||
uint64_t k_left = k0 % 4;
|
||||
uint64_t k_mker = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
uint64_t info = 0;
|
||||
|
||||
uint64_t mr = svcntw();
|
||||
GEMM_UKR_SETUP_CT( c, mr, 10, false );
|
||||
|
||||
__asm__ volatile (
|
||||
// " ldr x0, %[a] \n\t"
|
||||
// " ldr x1, %[b] \n\t"
|
||||
@@ -310,5 +317,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
|
||||
"z24","z25","z26","z27",
|
||||
"z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( c );
|
||||
}
|
||||
|
||||
|
||||
@@ -42,9 +42,13 @@
|
||||
// 2vx10 microkernels.
|
||||
#include "armsve_asm_2vx10.h"
|
||||
|
||||
#include "arm_sve.h"
|
||||
|
||||
void bli_dgemm_armsve_asm_2vx10_unindexed
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
@@ -59,11 +63,14 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_mker = k0 / 4;
|
||||
uint64_t k_left = k0 % 4;
|
||||
uint64_t k_mker = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
uint64_t mr = 2*svcntd();
|
||||
GEMM_UKR_SETUP_CT( d, mr, 10, false );
|
||||
|
||||
__asm__ volatile (
|
||||
" ldr x0, %[a] \n\t"
|
||||
" ldr x1, %[b] \n\t"
|
||||
@@ -324,5 +331,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x
|
||||
"z24","z25","z26","z27",
|
||||
"z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
|
||||
@@ -42,9 +42,13 @@
|
||||
// 2vx10 microkernels.
|
||||
#include "armsve_asm_2vx10.h"
|
||||
|
||||
#include "arm_sve.h"
|
||||
|
||||
void bli_sgemm_armsve_asm_2vx10_unindexed
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
@@ -59,11 +63,14 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_mker = k0 / 4;
|
||||
uint64_t k_left = k0 % 4;
|
||||
uint64_t k_mker = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
uint64_t mr = 2*svcntw();
|
||||
GEMM_UKR_SETUP_CT( s, mr, 10, false );
|
||||
|
||||
__asm__ volatile (
|
||||
" ldr x0, %[a] \n\t"
|
||||
" ldr x1, %[b] \n\t"
|
||||
@@ -310,5 +317,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x
|
||||
"z24","z25","z26","z27",
|
||||
"z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
}
|
||||
|
||||
|
||||
@@ -42,9 +42,13 @@
|
||||
// 2vx10 microkernels.
|
||||
#include "armsve_asm_2vx10cmplx.h"
|
||||
|
||||
#include "arm_sve.h"
|
||||
|
||||
void bli_zgemm_armsve_asm_2vx10_unindexed
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
dcomplex* restrict b,
|
||||
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx10_unindexed
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_mker = k0 / 4;
|
||||
uint64_t k_left = k0 % 4;
|
||||
uint64_t k_mker = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
uint64_t info = 0;
|
||||
|
||||
uint64_t mr = svcntd();
|
||||
GEMM_UKR_SETUP_CT( z, mr, 10, false );
|
||||
|
||||
__asm__ volatile (
|
||||
// " ldr x0, %[a] \n\t"
|
||||
// " ldr x1, %[b] \n\t"
|
||||
@@ -309,5 +316,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
|
||||
"z24","z25","z26","z27",
|
||||
"z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( z );
|
||||
}
|
||||
|
||||
|
||||
@@ -42,9 +42,13 @@
|
||||
// 2vx7 microkernels.
|
||||
#include "armsve_asm_2vx7cmplx.h"
|
||||
|
||||
#include "arm_sve.h"
|
||||
|
||||
void bli_zgemm_armsve_asm_2vx7_unindexed
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
dcomplex* restrict b,
|
||||
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx7_unindexed
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_mker = k0 / 4;
|
||||
uint64_t k_left = k0 % 4;
|
||||
uint64_t k_mker = k / 4;
|
||||
uint64_t k_left = k % 4;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
uint64_t info = 0;
|
||||
|
||||
uint64_t mr = svcntd();
|
||||
GEMM_UKR_SETUP_CT( z, mr, 7, false );
|
||||
|
||||
__asm__ volatile (
|
||||
// " ldr x0, %[a] \n\t"
|
||||
// " ldr x1, %[b] \n\t"
|
||||
@@ -261,6 +268,8 @@ GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27
|
||||
"z24","z25","z26","z27",
|
||||
"z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( z );
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -42,9 +42,13 @@
|
||||
// 2vx8 microkernels.
|
||||
#include "armsve_asm_2vx8cmplx.h"
|
||||
|
||||
#include "arm_sve.h"
|
||||
|
||||
void bli_zgemm_armsve_asm_2vx8_unindexed
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
dcomplex* restrict b,
|
||||
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx8_unindexed
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_mker = k0 / 6;
|
||||
uint64_t k_left = k0 % 6;
|
||||
uint64_t k_mker = k / 6;
|
||||
uint64_t k_left = k % 6;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
uint64_t info = 0;
|
||||
|
||||
uint64_t mr = svcntd();
|
||||
GEMM_UKR_SETUP_CT( z, mr, 8, false );
|
||||
|
||||
__asm__ volatile (
|
||||
// " ldr x0, %[a] \n\t"
|
||||
// " ldr x1, %[b] \n\t"
|
||||
@@ -286,5 +293,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16)
|
||||
"z24","z25","z26","z27",
|
||||
"z28","z29","z30","z31"
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( z );
|
||||
}
|
||||
|
||||
|
||||
@@ -48,23 +48,23 @@ void bli_sgemm_armv7a_ker_4x4
|
||||
|
||||
void bli_sgemm_armv7a_asm_4x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
float* restrict beta,
|
||||
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
float* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint32_t k = k0;
|
||||
uint32_t rs_c = rs_c0;
|
||||
uint32_t cs_c = cs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT_ANY( s, 4, 4, false );
|
||||
bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
}
|
||||
|
||||
|
||||
@@ -83,23 +83,23 @@ void bli_dgemm_armv7a_ker_4x4
|
||||
|
||||
void bli_dgemm_armv7a_asm_4x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
double* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint32_t k = k0;
|
||||
uint32_t rs_c = rs_c0;
|
||||
uint32_t cs_c = cs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false );
|
||||
bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
|
||||
@@ -118,23 +118,23 @@ void bli_cgemm_armv7a_ker_2x2
|
||||
|
||||
void bli_cgemm_armv7a_asm_2x2
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* restrict alpha,
|
||||
scomplex* restrict a,
|
||||
scomplex* restrict b,
|
||||
scomplex* restrict beta,
|
||||
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
scomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint32_t k = k0;
|
||||
uint32_t rs_c = rs_c0;
|
||||
uint32_t cs_c = cs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT_ANY( c, 2, 2, false );
|
||||
bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||
GEMM_UKR_FLUSH_CT( c );
|
||||
}
|
||||
|
||||
|
||||
@@ -153,22 +153,22 @@ void bli_zgemm_armv7a_ker_2x2
|
||||
|
||||
void bli_zgemm_armv7a_asm_2x2
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
dcomplex* restrict b,
|
||||
dcomplex* restrict beta,
|
||||
dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
dcomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint32_t k = k0;
|
||||
uint32_t rs_c = rs_c0;
|
||||
uint32_t cs_c = cs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT_ANY( z, 2, 2, false );
|
||||
bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||
GEMM_UKR_FLUSH_CT( z );
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,9 @@
|
||||
|
||||
void bli_sgemm_armv7a_int_4x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
@@ -49,12 +51,14 @@ void bli_sgemm_armv7a_int_4x4
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint32_t k_iter = k0 / 4;
|
||||
uint32_t k_left = k0 % 4;
|
||||
uint32_t k_iter = k / 4;
|
||||
uint32_t k_left = k % 4;
|
||||
uint32_t rs_c = rs_c0;
|
||||
uint32_t cs_c = cs_c0;
|
||||
uint32_t i;
|
||||
|
||||
GEMM_UKR_SETUP_CT( s, 4, 4, false );
|
||||
|
||||
void* a_next = bli_auxinfo_next_a( data );
|
||||
void* b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
@@ -82,47 +86,17 @@ void bli_sgemm_armv7a_int_4x4
|
||||
|
||||
if ( *beta != 0.0F )
|
||||
{
|
||||
if ( rs_c == 1 )
|
||||
{
|
||||
// Load column 0
|
||||
cv0 = vld1q_f32( c + 0*rs_c + 0*cs_c );
|
||||
// Load column 0
|
||||
cv0 = vld1q_f32( c + 0*cs_c );
|
||||
|
||||
// Load column 1
|
||||
cv1 = vld1q_f32( c + 0*rs_c + 1*cs_c );
|
||||
// Load column 1
|
||||
cv1 = vld1q_f32( c + 1*cs_c );
|
||||
|
||||
// Load column 2
|
||||
cv2 = vld1q_f32( c + 0*rs_c + 2*cs_c );
|
||||
// Load column 2
|
||||
cv2 = vld1q_f32( c + 2*cs_c );
|
||||
|
||||
// Load column 3
|
||||
cv3 = vld1q_f32( c + 0*rs_c + 3*cs_c );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Load column 0
|
||||
cv0 = vld1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0);
|
||||
cv0 = vld1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1);
|
||||
cv0 = vld1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2);
|
||||
cv0 = vld1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3);
|
||||
|
||||
// Load column 1
|
||||
cv1 = vld1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0);
|
||||
cv1 = vld1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1);
|
||||
cv1 = vld1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2);
|
||||
cv1 = vld1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3);
|
||||
|
||||
// Load column 2
|
||||
cv2 = vld1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0);
|
||||
cv2 = vld1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1);
|
||||
cv2 = vld1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2);
|
||||
cv2 = vld1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3);
|
||||
|
||||
// Load column 3
|
||||
cv3 = vld1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0);
|
||||
cv3 = vld1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1);
|
||||
cv3 = vld1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2);
|
||||
cv3 = vld1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3);
|
||||
|
||||
}
|
||||
// Load column 3
|
||||
cv3 = vld1q_f32( c + 3*cs_c );
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -255,47 +229,22 @@ void bli_sgemm_armv7a_int_4x4
|
||||
cv3 = vmlaq_f32( cv3, abv3, alphav );
|
||||
}
|
||||
|
||||
if ( rs_c == 1 )
|
||||
{
|
||||
// Store column 0
|
||||
vst1q_f32( c + 0*rs_c + 0*cs_c, cv0 );
|
||||
// Store column 1
|
||||
vst1q_f32( c + 0*rs_c + 1*cs_c, cv1 );
|
||||
// Store column 2
|
||||
vst1q_f32( c + 0*rs_c + 2*cs_c, cv2 );
|
||||
// Store column 3
|
||||
vst1q_f32( c + 0*rs_c + 3*cs_c, cv3 );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Store column 0
|
||||
vst1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0);
|
||||
vst1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1);
|
||||
vst1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2);
|
||||
vst1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3);
|
||||
// Store column 0
|
||||
vst1q_f32( c + 0*cs_c, cv0 );
|
||||
// Store column 1
|
||||
vst1q_f32( c + 1*cs_c, cv1 );
|
||||
// Store column 2
|
||||
vst1q_f32( c + 2*cs_c, cv2 );
|
||||
// Store column 3
|
||||
vst1q_f32( c + 3*cs_c, cv3 );
|
||||
|
||||
// Store column 1
|
||||
vst1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0);
|
||||
vst1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1);
|
||||
vst1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2);
|
||||
vst1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3);
|
||||
|
||||
// Store column 2
|
||||
vst1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0);
|
||||
vst1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1);
|
||||
vst1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2);
|
||||
vst1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3);
|
||||
|
||||
// Store column 3
|
||||
vst1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0);
|
||||
vst1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1);
|
||||
vst1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2);
|
||||
vst1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3);
|
||||
}
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
}
|
||||
|
||||
void bli_dgemm_armv7a_int_4x4
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
@@ -314,6 +263,8 @@ void bli_dgemm_armv7a_int_4x4
|
||||
uint32_t cs_c = cs_c0;
|
||||
uint32_t i;
|
||||
|
||||
GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false );
|
||||
|
||||
//void* a_next = bli_auxinfo_next_a( data );
|
||||
//void* b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
@@ -568,5 +519,7 @@ void bli_dgemm_armv7a_int_4x4
|
||||
*c23 += ab23 * *alpha;
|
||||
*c33 += ab33 * *alpha;
|
||||
}
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -56,6 +56,8 @@
|
||||
|
||||
void bli_dgemm_bgq_int_8x8
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
@@ -66,6 +68,8 @@ void bli_dgemm_bgq_int_8x8
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
GEMM_UKR_SETUP_CT_ANY( d, 8, 8, false );
|
||||
|
||||
//Registers for storing C.
|
||||
//4 4x4 subblocks of C, c00, c01, c10, c11
|
||||
//4 registers per subblock: a, b, c, d
|
||||
@@ -201,6 +205,8 @@ void bli_dgemm_bgq_int_8x8
|
||||
UPDATE( AB, c, 0 );
|
||||
AB = vec_perm( c11d, c11d, pattern );
|
||||
UPDATE( AB, c, 4 );
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
void printvec(vector4double v)
|
||||
@@ -214,6 +220,8 @@ void printvec(vector4double v)
|
||||
|
||||
void bli_zgemm_bgq_int_4x4
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
@@ -224,6 +232,8 @@ void bli_zgemm_bgq_int_4x4
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
GEMM_UKR_SETUP_CT_ANY( z, 4, 4, false );
|
||||
|
||||
double* a_d = ( double* )a;
|
||||
double* b_d = ( double* )b;
|
||||
double* c_d = ( double* )c;
|
||||
@@ -368,4 +378,6 @@ void bli_zgemm_bgq_int_4x4
|
||||
c_d += 2*cs_c;
|
||||
ZUPDATE( c03a, c03b, c_d, 0 );
|
||||
ZUPDATE( c13a, c13b, c_d, 4 );
|
||||
|
||||
GEMM_UKR_FLUSH_CT( z );
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -256,6 +256,8 @@ extern int offsets[16];
|
||||
//#define LOOPMON
|
||||
void bli_dgemm_knc_asm_30x8
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
@@ -273,80 +275,82 @@ void bli_dgemm_knc_asm_30x8
|
||||
|
||||
uint64_t k64 = k;
|
||||
|
||||
GEMM_UKR_SETUP_CT( d, 30, 8, true );
|
||||
|
||||
#ifdef MONITORS
|
||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||
#endif
|
||||
#ifdef LOOPMON
|
||||
int tlooph, tloopl, blooph, bloopl;
|
||||
#endif
|
||||
|
||||
|
||||
__asm
|
||||
{
|
||||
#ifdef MONITORS
|
||||
rdtsc
|
||||
mov topl, eax
|
||||
mov toph, edx
|
||||
mov toph, edx
|
||||
#endif
|
||||
vpxord zmm0, zmm0, zmm0
|
||||
vmovaps zmm1, zmm0 //clear out registers
|
||||
vmovaps zmm2, zmm0
|
||||
vmovaps zmm2, zmm0
|
||||
mov rsi, k64 //loop index
|
||||
vmovaps zmm3, zmm0
|
||||
vmovaps zmm3, zmm0
|
||||
|
||||
mov r11, rs_c //load row stride
|
||||
vmovaps zmm4, zmm0
|
||||
vmovaps zmm4, zmm0
|
||||
sal r11, 3 //scale row stride
|
||||
vmovaps zmm5, zmm0
|
||||
vmovaps zmm5, zmm0
|
||||
mov r15, a //load address of a
|
||||
vmovaps zmm6, zmm0
|
||||
vmovaps zmm6, zmm0
|
||||
mov rbx, b //load address of b
|
||||
vmovaps zmm7, zmm0
|
||||
vmovaps zmm7, zmm0
|
||||
|
||||
vmovaps zmm8, zmm0
|
||||
vmovaps zmm8, zmm0
|
||||
lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11
|
||||
vmovaps zmm9, zmm0
|
||||
vmovaps zmm10, zmm0
|
||||
mov rdi, r11
|
||||
vmovaps zmm11, zmm0
|
||||
vmovaps zmm10, zmm0
|
||||
mov rdi, r11
|
||||
vmovaps zmm11, zmm0
|
||||
sal rdi, 2 //rdi has 4*r11
|
||||
|
||||
vmovaps zmm12, zmm0
|
||||
vmovaps zmm12, zmm0
|
||||
mov rcx, c //load address of c for prefetching
|
||||
vmovaps zmm13, zmm0
|
||||
vmovaps zmm14, zmm0
|
||||
vmovaps zmm13, zmm0
|
||||
vmovaps zmm14, zmm0
|
||||
mov r8, k64
|
||||
vmovaps zmm15, zmm0
|
||||
vmovaps zmm15, zmm0
|
||||
|
||||
vmovaps zmm16, zmm0
|
||||
vmovaps zmm17, zmm0
|
||||
mov r13, L2_PREFETCH_DIST*8*8
|
||||
vmovaps zmm18, zmm0
|
||||
vmovaps zmm18, zmm0
|
||||
mov r14, L2_PREFETCH_DIST*8*32
|
||||
vmovaps zmm19, zmm0
|
||||
vmovaps zmm20, zmm0
|
||||
vmovaps zmm21, zmm0
|
||||
vmovaps zmm22, zmm0
|
||||
vmovaps zmm19, zmm0
|
||||
vmovaps zmm20, zmm0
|
||||
vmovaps zmm21, zmm0
|
||||
vmovaps zmm22, zmm0
|
||||
|
||||
vmovaps zmm23, zmm0
|
||||
vmovaps zmm23, zmm0
|
||||
sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do.
|
||||
vmovaps zmm24, zmm0
|
||||
vmovaps zmm24, zmm0
|
||||
mov r8, 30
|
||||
vmovaps zmm25, zmm0
|
||||
vmovaps zmm25, zmm0
|
||||
mov r9, 8*8 //amount to increment b* by each iteration
|
||||
vmovaps zmm26, zmm0
|
||||
vmovaps zmm26, zmm0
|
||||
mov r12, 32*8 //amount to increment a* by each iteration
|
||||
vmovaps zmm27, zmm0
|
||||
vmovaps zmm28, zmm0
|
||||
vmovaps zmm29, zmm0
|
||||
vmovaps zmm27, zmm0
|
||||
vmovaps zmm28, zmm0
|
||||
vmovaps zmm29, zmm0
|
||||
|
||||
#ifdef MONITORS
|
||||
rdtsc
|
||||
mov midl, eax
|
||||
mov midh, edx
|
||||
mov midh, edx
|
||||
#endif
|
||||
jle CONSIDER_UNDER_40
|
||||
sub rsi, 30 + L2_PREFETCH_DIST
|
||||
|
||||
|
||||
//First 30 iterations
|
||||
LOOPREFECHCL2:
|
||||
ONE_ITER_PC_L2(rcx)
|
||||
@@ -357,26 +361,26 @@ void bli_dgemm_knc_asm_30x8
|
||||
LOOPMAIN:
|
||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||
jne LOOPMAIN
|
||||
|
||||
|
||||
//Penultimate 22 iterations.
|
||||
//Break these off from the main loop to avoid prefetching extra shit.
|
||||
mov r14, a_next
|
||||
mov r13, b_next
|
||||
sub r14, r15
|
||||
sub r13, rbx
|
||||
|
||||
|
||||
mov rsi, L2_PREFETCH_DIST-10
|
||||
LOOPMAIN2:
|
||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||
jne LOOPMAIN2
|
||||
|
||||
|
||||
|
||||
|
||||
//Last 10 iterations
|
||||
mov r8, 10
|
||||
LOOPREFETCHCL1:
|
||||
ONE_ITER_PC_L1(rcx)
|
||||
jne LOOPREFETCHCL1
|
||||
|
||||
|
||||
|
||||
jmp POSTACCUM
|
||||
|
||||
@@ -403,14 +407,8 @@ void bli_dgemm_knc_asm_30x8
|
||||
mov r9, c //load address of c for update
|
||||
mov r12, alpha //load address of alpha
|
||||
|
||||
// Check if C is row stride. If not, jump to the slow scattered update
|
||||
mov r14, cs_c
|
||||
dec r14
|
||||
jne SCATTEREDUPDATE
|
||||
|
||||
mov r14, beta
|
||||
vbroadcastsd zmm31, 0[r14]
|
||||
|
||||
vbroadcastsd zmm31, 0[r14]
|
||||
|
||||
vmulpd zmm0, zmm0, 0[r12]{1to8}
|
||||
vmulpd zmm1, zmm1, 0[r12]{1to8}
|
||||
@@ -467,7 +465,7 @@ void bli_dgemm_knc_asm_30x8
|
||||
vmovapd [r9+2*r11+0], zmm14
|
||||
vmovapd [r9+r10+0], zmm15
|
||||
add r9, rdi
|
||||
|
||||
|
||||
vmulpd zmm16, zmm16, 0[r12]{1to8}
|
||||
vmulpd zmm17, zmm17, 0[r12]{1to8}
|
||||
vmulpd zmm18, zmm18, 0[r12]{1to8}
|
||||
@@ -516,47 +514,6 @@ void bli_dgemm_knc_asm_30x8
|
||||
vfmadd231pd zmm29, zmm31, [r9+r11+0]
|
||||
vmovapd [r9+0], zmm28
|
||||
vmovapd [r9+r11+0], zmm29
|
||||
|
||||
jmp END
|
||||
|
||||
SCATTEREDUPDATE:
|
||||
mov r10, offsetPtr
|
||||
vmovapd zmm31, 0[r10]
|
||||
vpbroadcastd zmm30, cs_c
|
||||
mov r13, beta
|
||||
vpmulld zmm30, zmm31, zmm30
|
||||
|
||||
mov ebx, 255
|
||||
UPDATE_C_ROW_SCATTERED(zmm0, 0, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm1, 1, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm2, 2, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm3, 3, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm4, 4, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm5, 5, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm6, 6, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm7, 7, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm8, 8, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm9, 9, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm10, 10, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm11, 11, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm12, 12, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm13, 13, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm14, 14, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm15, 15, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm16, 16, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm17, 17, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm18, 18, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm19, 19, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm20, 20, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm21, 21, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm22, 22, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm23, 23, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm24, 24, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm25, 25, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm26, 26, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm27, 27, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm28, 28, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm29, 29, r9)
|
||||
|
||||
END:
|
||||
#ifdef MONITORS
|
||||
@@ -566,6 +523,8 @@ void bli_dgemm_knc_asm_30x8
|
||||
#endif
|
||||
}
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
|
||||
#ifdef LOOPMON
|
||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||
#endif
|
||||
|
||||
@@ -256,6 +256,8 @@ int offsets[16] __attribute__((aligned(0x1000))) = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||
//#define LOOPMON
|
||||
void bli_sgemm_knc_asm_30x16
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
@@ -273,80 +275,82 @@ void bli_sgemm_knc_asm_30x16
|
||||
|
||||
uint64_t k64 = k;
|
||||
|
||||
GEMM_UKR_SETUP_CT( s, 30, 16, true );
|
||||
|
||||
#ifdef MONITORS
|
||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||
#endif
|
||||
#ifdef LOOPMON
|
||||
int tlooph, tloopl, blooph, bloopl;
|
||||
#endif
|
||||
|
||||
|
||||
__asm
|
||||
{
|
||||
#ifdef MONITORS
|
||||
rdtsc
|
||||
mov topl, eax
|
||||
mov toph, edx
|
||||
mov toph, edx
|
||||
#endif
|
||||
vpxord zmm0, zmm0, zmm0
|
||||
vmovaps zmm1, zmm0 //clear out registers
|
||||
vmovaps zmm2, zmm0
|
||||
vmovaps zmm2, zmm0
|
||||
mov rsi, k64 //loop index
|
||||
vmovaps zmm3, zmm0
|
||||
vmovaps zmm3, zmm0
|
||||
|
||||
mov r11, rs_c //load row stride
|
||||
vmovaps zmm4, zmm0
|
||||
vmovaps zmm4, zmm0
|
||||
sal r11, 2 //scale row stride
|
||||
vmovaps zmm5, zmm0
|
||||
vmovaps zmm5, zmm0
|
||||
mov r15, a //load address of a
|
||||
vmovaps zmm6, zmm0
|
||||
vmovaps zmm6, zmm0
|
||||
mov rbx, b //load address of b
|
||||
vmovaps zmm7, zmm0
|
||||
vmovaps zmm7, zmm0
|
||||
|
||||
vmovaps zmm8, zmm0
|
||||
vmovaps zmm8, zmm0
|
||||
lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11
|
||||
vmovaps zmm9, zmm0
|
||||
vmovaps zmm10, zmm0
|
||||
mov rdi, r11
|
||||
vmovaps zmm11, zmm0
|
||||
vmovaps zmm10, zmm0
|
||||
mov rdi, r11
|
||||
vmovaps zmm11, zmm0
|
||||
sal rdi, 2 //rdi has 4*r11
|
||||
|
||||
vmovaps zmm12, zmm0
|
||||
vmovaps zmm12, zmm0
|
||||
mov rcx, c //load address of c for prefetching
|
||||
vmovaps zmm13, zmm0
|
||||
vmovaps zmm14, zmm0
|
||||
vmovaps zmm13, zmm0
|
||||
vmovaps zmm14, zmm0
|
||||
mov r8, k64
|
||||
vmovaps zmm15, zmm0
|
||||
vmovaps zmm15, zmm0
|
||||
|
||||
vmovaps zmm16, zmm0
|
||||
vmovaps zmm17, zmm0
|
||||
mov r13, L2_PREFETCH_DIST*4*16
|
||||
vmovaps zmm18, zmm0
|
||||
vmovaps zmm18, zmm0
|
||||
mov r14, L2_PREFETCH_DIST*4*32
|
||||
vmovaps zmm19, zmm0
|
||||
vmovaps zmm20, zmm0
|
||||
vmovaps zmm21, zmm0
|
||||
vmovaps zmm22, zmm0
|
||||
vmovaps zmm19, zmm0
|
||||
vmovaps zmm20, zmm0
|
||||
vmovaps zmm21, zmm0
|
||||
vmovaps zmm22, zmm0
|
||||
|
||||
vmovaps zmm23, zmm0
|
||||
vmovaps zmm23, zmm0
|
||||
sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do.
|
||||
vmovaps zmm24, zmm0
|
||||
vmovaps zmm24, zmm0
|
||||
mov r8, 30
|
||||
vmovaps zmm25, zmm0
|
||||
vmovaps zmm25, zmm0
|
||||
mov r9, 16*4 //amount to increment b* by each iteration
|
||||
vmovaps zmm26, zmm0
|
||||
vmovaps zmm26, zmm0
|
||||
mov r12, 32*4 //amount to increment a* by each iteration
|
||||
vmovaps zmm27, zmm0
|
||||
vmovaps zmm28, zmm0
|
||||
vmovaps zmm29, zmm0
|
||||
vmovaps zmm27, zmm0
|
||||
vmovaps zmm28, zmm0
|
||||
vmovaps zmm29, zmm0
|
||||
|
||||
#ifdef MONITORS
|
||||
rdtsc
|
||||
mov midl, eax
|
||||
mov midh, edx
|
||||
mov midh, edx
|
||||
#endif
|
||||
jle CONSIDER_UNDER_40
|
||||
sub rsi, 30 + L2_PREFETCH_DIST
|
||||
|
||||
|
||||
//First 30 iterations
|
||||
LOOPREFECHCL2:
|
||||
ONE_ITER_PC_L2(rcx)
|
||||
@@ -357,26 +361,26 @@ void bli_sgemm_knc_asm_30x16
|
||||
LOOPMAIN:
|
||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||
jne LOOPMAIN
|
||||
|
||||
|
||||
//Penultimate 22 iterations.
|
||||
//Break these off from the main loop to avoid prefetching extra shit.
|
||||
mov r14, a_next
|
||||
mov r13, b_next
|
||||
sub r14, r15
|
||||
sub r13, rbx
|
||||
|
||||
|
||||
mov rsi, L2_PREFETCH_DIST-10
|
||||
LOOPMAIN2:
|
||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||
jne LOOPMAIN2
|
||||
|
||||
|
||||
|
||||
|
||||
//Last 10 iterations
|
||||
mov r8, 10
|
||||
LOOPREFETCHCL1:
|
||||
ONE_ITER_PC_L1(rcx)
|
||||
jne LOOPREFETCHCL1
|
||||
|
||||
|
||||
|
||||
jmp POSTACCUM
|
||||
|
||||
@@ -384,7 +388,7 @@ void bli_sgemm_knc_asm_30x16
|
||||
//Used when <= 40 iterations
|
||||
CONSIDER_UNDER_40:
|
||||
mov rsi, k64
|
||||
test rsi, rsi
|
||||
test rsi, rsi
|
||||
je POSTACCUM
|
||||
LOOP_UNDER_40:
|
||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||
@@ -403,13 +407,8 @@ void bli_sgemm_knc_asm_30x16
|
||||
mov r9, c //load address of c for update
|
||||
mov r12, alpha //load address of alpha
|
||||
|
||||
// Check if C is row stride. If not, jump to the slow scattered update
|
||||
mov r14, cs_c
|
||||
dec r14
|
||||
jne SCATTEREDUPDATE
|
||||
|
||||
mov r14, beta
|
||||
vbroadcastss zmm31, 0[r14]
|
||||
vbroadcastss zmm31, 0[r14]
|
||||
|
||||
|
||||
vmulps zmm0, zmm0, 0[r12]{1to16}
|
||||
@@ -467,7 +466,7 @@ void bli_sgemm_knc_asm_30x16
|
||||
vmovaps [r9+2*r11+0], zmm14
|
||||
vmovaps [r9+r10+0], zmm15
|
||||
add r9, rdi
|
||||
|
||||
|
||||
vmulps zmm16, zmm16, 0[r12]{1to16}
|
||||
vmulps zmm17, zmm17, 0[r12]{1to16}
|
||||
vmulps zmm18, zmm18, 0[r12]{1to16}
|
||||
@@ -516,48 +515,6 @@ void bli_sgemm_knc_asm_30x16
|
||||
vfmadd231ps zmm29, zmm31, [r9+r11+0]
|
||||
vmovaps [r9+0], zmm28
|
||||
vmovaps [r9+r11+0], zmm29
|
||||
|
||||
jmp END
|
||||
|
||||
SCATTEREDUPDATE:
|
||||
|
||||
mov r10, offsetPtr
|
||||
vmovaps zmm31, 0[r10]
|
||||
vpbroadcastd zmm30, cs_c
|
||||
mov r13, beta
|
||||
vpmulld zmm30, zmm31, zmm30
|
||||
|
||||
mov ebx, 0xFFFF
|
||||
UPDATE_C_ROW_SCATTERED(zmm0, 0, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm1, 1, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm2, 2, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm3, 3, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm4, 4, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm5, 5, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm6, 6, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm7, 7, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm8, 8, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm9, 9, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm10, 10, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm11, 11, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm12, 12, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm13, 13, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm14, 14, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm15, 15, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm16, 16, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm17, 17, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm18, 18, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm19, 19, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm20, 20, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm21, 21, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm22, 22, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm23, 23, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm24, 24, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm25, 25, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm26, 26, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm27, 27, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm28, 28, r9)
|
||||
UPDATE_C_ROW_SCATTERED(zmm29, 29, r9)
|
||||
|
||||
END:
|
||||
#ifdef MONITORS
|
||||
@@ -567,6 +524,8 @@ void bli_sgemm_knc_asm_30x16
|
||||
#endif
|
||||
}
|
||||
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
|
||||
#ifdef LOOPMON
|
||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||
#endif
|
||||
|
||||
@@ -185,6 +185,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) =
|
||||
//#define LOOPMON
|
||||
void bli_dgemm_knl_asm_24x8
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k_,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
@@ -201,10 +203,12 @@ void bli_dgemm_knl_asm_24x8
|
||||
const double * a_next = bli_auxinfo_next_a( data );
|
||||
const double * b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
const int32_t * offsetPtr = &offsets[0];
|
||||
const int64_t k = k_;
|
||||
const int64_t rs_c = rs_c_;
|
||||
const int64_t cs_c = cs_c_;
|
||||
int32_t * offsetPtr = &offsets[0];
|
||||
int64_t k = k_;
|
||||
int64_t rs_c = rs_c_;
|
||||
int64_t cs_c = cs_c_;
|
||||
|
||||
GEMM_UKR_SETUP_CT( d, 24, 8, true );
|
||||
|
||||
#ifdef MONITORS
|
||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||
@@ -565,10 +569,7 @@ void bli_dgemm_knl_asm_24x8
|
||||
// Check if C is row stride. If not, jump to the slow scattered update
|
||||
MOV(RAX, VAR(rs_c))
|
||||
LEA(RAX, MEM(,RAX,8))
|
||||
MOV(RBX, VAR(cs_c))
|
||||
LEA(RDI, MEM(RAX,RAX,2))
|
||||
CMP(RBX, IMM(1))
|
||||
JNE(SCATTEREDUPDATE)
|
||||
|
||||
VMOVQ(RDX, XMM(1))
|
||||
SAL(RDX) //shift out sign bit
|
||||
@@ -592,74 +593,6 @@ void bli_dgemm_knl_asm_24x8
|
||||
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
|
||||
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
|
||||
LABEL(SCATTEREDUPDATE)
|
||||
|
||||
MOV(RDI, VAR(offsetPtr))
|
||||
VMOVAPS(ZMM(2), MEM(RDI))
|
||||
/* Note that this ignores the upper 32 bits in cs_c */
|
||||
VPBROADCASTD(ZMM(3), EBX)
|
||||
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
|
||||
|
||||
VMOVQ(RDX, XMM(1))
|
||||
SAL(RDX) //shift out sign bit
|
||||
JZ(SCATTERBZ)
|
||||
|
||||
UPDATE_C_ROW_SCATTERED( 8)
|
||||
UPDATE_C_ROW_SCATTERED( 9)
|
||||
UPDATE_C_ROW_SCATTERED(10)
|
||||
UPDATE_C_ROW_SCATTERED(11)
|
||||
UPDATE_C_ROW_SCATTERED(12)
|
||||
UPDATE_C_ROW_SCATTERED(13)
|
||||
UPDATE_C_ROW_SCATTERED(14)
|
||||
UPDATE_C_ROW_SCATTERED(15)
|
||||
UPDATE_C_ROW_SCATTERED(16)
|
||||
UPDATE_C_ROW_SCATTERED(17)
|
||||
UPDATE_C_ROW_SCATTERED(18)
|
||||
UPDATE_C_ROW_SCATTERED(19)
|
||||
UPDATE_C_ROW_SCATTERED(20)
|
||||
UPDATE_C_ROW_SCATTERED(21)
|
||||
UPDATE_C_ROW_SCATTERED(22)
|
||||
UPDATE_C_ROW_SCATTERED(23)
|
||||
UPDATE_C_ROW_SCATTERED(24)
|
||||
UPDATE_C_ROW_SCATTERED(25)
|
||||
UPDATE_C_ROW_SCATTERED(26)
|
||||
UPDATE_C_ROW_SCATTERED(27)
|
||||
UPDATE_C_ROW_SCATTERED(28)
|
||||
UPDATE_C_ROW_SCATTERED(29)
|
||||
UPDATE_C_ROW_SCATTERED(30)
|
||||
UPDATE_C_ROW_SCATTERED(31)
|
||||
|
||||
JMP(END)
|
||||
|
||||
LABEL(SCATTERBZ)
|
||||
|
||||
UPDATE_C_BZ_ROW_SCATTERED( 8)
|
||||
UPDATE_C_BZ_ROW_SCATTERED( 9)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(10)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(11)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(12)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(13)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(14)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(15)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(16)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(17)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(18)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(19)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(20)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(21)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(22)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(23)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(24)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(25)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(26)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(27)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(28)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(29)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(30)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(31)
|
||||
|
||||
LABEL(END)
|
||||
|
||||
#ifdef MONITORS
|
||||
@@ -701,6 +634,8 @@ void bli_dgemm_knl_asm_24x8
|
||||
"zmm30", "zmm31", "memory"
|
||||
)
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
|
||||
#ifdef LOOPMON
|
||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||
#endif
|
||||
|
||||
@@ -182,6 +182,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) =
|
||||
//#define LOOPMON
|
||||
void bli_sgemm_knl_asm_24x16
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k_,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
@@ -198,10 +200,12 @@ void bli_sgemm_knl_asm_24x16
|
||||
const double * a_next = bli_auxinfo_next_a( data );
|
||||
const double * b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
const int32_t * offsetPtr = &offsets[0];
|
||||
const int64_t k = k_;
|
||||
const int64_t rs_c = rs_c_;
|
||||
const int64_t cs_c = cs_c_;
|
||||
int32_t * offsetPtr = &offsets[0];
|
||||
int64_t k = k_;
|
||||
int64_t rs_c = rs_c_;
|
||||
int64_t cs_c = cs_c_;
|
||||
|
||||
GEMM_UKR_SETUP_CT( s, 24, 16, true );
|
||||
|
||||
#ifdef MONITORS
|
||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||
@@ -562,10 +566,7 @@ void bli_sgemm_knl_asm_24x16
|
||||
// Check if C is row stride. If not, jump to the slow scattered update
|
||||
MOV(RAX, VAR(rs_c))
|
||||
LEA(RAX, MEM(,RAX,4))
|
||||
MOV(RBX, VAR(cs_c))
|
||||
LEA(RDI, MEM(RAX,RAX,2))
|
||||
CMP(RBX, IMM(1))
|
||||
JNE(SCATTEREDUPDATE)
|
||||
|
||||
VMOVD(EDX, XMM(1))
|
||||
SAL(EDX) //shift out sign bit
|
||||
@@ -589,74 +590,6 @@ void bli_sgemm_knl_asm_24x16
|
||||
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
|
||||
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
|
||||
LABEL(SCATTEREDUPDATE)
|
||||
|
||||
MOV(RDI, VAR(offsetPtr))
|
||||
VMOVAPS(ZMM(2), MEM(RDI))
|
||||
/* Note that this ignores the upper 32 bits in cs_c */
|
||||
VPBROADCASTD(ZMM(3), EBX)
|
||||
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
|
||||
|
||||
VMOVD(EDX, XMM(1))
|
||||
SAL(EDX) //shift out sign bit
|
||||
JZ(SCATTERBZ)
|
||||
|
||||
UPDATE_C_ROW_SCATTERED( 8)
|
||||
UPDATE_C_ROW_SCATTERED( 9)
|
||||
UPDATE_C_ROW_SCATTERED(10)
|
||||
UPDATE_C_ROW_SCATTERED(11)
|
||||
UPDATE_C_ROW_SCATTERED(12)
|
||||
UPDATE_C_ROW_SCATTERED(13)
|
||||
UPDATE_C_ROW_SCATTERED(14)
|
||||
UPDATE_C_ROW_SCATTERED(15)
|
||||
UPDATE_C_ROW_SCATTERED(16)
|
||||
UPDATE_C_ROW_SCATTERED(17)
|
||||
UPDATE_C_ROW_SCATTERED(18)
|
||||
UPDATE_C_ROW_SCATTERED(19)
|
||||
UPDATE_C_ROW_SCATTERED(20)
|
||||
UPDATE_C_ROW_SCATTERED(21)
|
||||
UPDATE_C_ROW_SCATTERED(22)
|
||||
UPDATE_C_ROW_SCATTERED(23)
|
||||
UPDATE_C_ROW_SCATTERED(24)
|
||||
UPDATE_C_ROW_SCATTERED(25)
|
||||
UPDATE_C_ROW_SCATTERED(26)
|
||||
UPDATE_C_ROW_SCATTERED(27)
|
||||
UPDATE_C_ROW_SCATTERED(28)
|
||||
UPDATE_C_ROW_SCATTERED(29)
|
||||
UPDATE_C_ROW_SCATTERED(30)
|
||||
UPDATE_C_ROW_SCATTERED(31)
|
||||
|
||||
JMP(END)
|
||||
|
||||
LABEL(SCATTERBZ)
|
||||
|
||||
UPDATE_C_BZ_ROW_SCATTERED( 8)
|
||||
UPDATE_C_BZ_ROW_SCATTERED( 9)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(10)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(11)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(12)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(13)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(14)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(15)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(16)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(17)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(18)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(19)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(20)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(21)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(22)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(23)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(24)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(25)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(26)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(27)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(28)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(29)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(30)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(31)
|
||||
|
||||
LABEL(END)
|
||||
|
||||
#ifdef MONITORS
|
||||
@@ -698,6 +631,8 @@ void bli_sgemm_knl_asm_24x16
|
||||
"zmm30", "zmm31", "memory"
|
||||
)
|
||||
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
|
||||
#ifdef LOOPMON
|
||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||
#endif
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -37,7 +37,7 @@
|
||||
|
||||
#define D_ASSEMBLE_VEC_PAIR \
|
||||
__builtin_mma_assemble_pair (&colA_1, ca[1], ca[0]); \
|
||||
__builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]);
|
||||
__builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]);
|
||||
|
||||
#define D_ACCUMULATE \
|
||||
__builtin_mma_xvf64gerpp (&acc0, colA_1, rb[0]); \
|
||||
@@ -47,7 +47,7 @@
|
||||
__builtin_mma_xvf64gerpp (&acc4, colA_2, rb[0]); \
|
||||
__builtin_mma_xvf64gerpp (&acc5, colA_2, rb[1]); \
|
||||
__builtin_mma_xvf64gerpp (&acc6, colA_2, rb[2]); \
|
||||
__builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]);
|
||||
__builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]);
|
||||
|
||||
#define D_INCREMENT \
|
||||
A0+=8; \
|
||||
@@ -57,17 +57,19 @@
|
||||
LOAD_VECTORS \
|
||||
D_ASSEMBLE_VEC_PAIR \
|
||||
D_INCREMENT \
|
||||
D_ACCUMULATE
|
||||
D_ACCUMULATE
|
||||
|
||||
|
||||
void bli_dgemm_power10_mma_8x8
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
double* restrict c, inc_t rs_c0, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
@@ -76,11 +78,13 @@ void bli_dgemm_power10_mma_8x8
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out)
|
||||
uint64_t k_iter = (k0-1) / 4;
|
||||
uint64_t k_left = (k0-1) % 4;
|
||||
uint64_t k_iter = (k-1) / 4;
|
||||
uint64_t k_left = (k-1) % 4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT( d, 8, 8, true );
|
||||
|
||||
double* restrict A0 = a;
|
||||
double* restrict B0 = b;
|
||||
double* restrict C0 = c;
|
||||
@@ -92,23 +96,23 @@ void bli_dgemm_power10_mma_8x8
|
||||
dv4sf_t *rowC;
|
||||
|
||||
/* 8 accumulator registers that will be used to store the result.
|
||||
|
||||
|
||||
Each accumulator register is mapped to 4 vector registers.
|
||||
Illustration:
|
||||
|
||||
|
||||
acc0 = [ vs0
|
||||
vs1
|
||||
vs3
|
||||
vs4 ]
|
||||
|
||||
These registers are used to store the result of an outer product
|
||||
These registers are used to store the result of an outer product
|
||||
instruction (general outer product instruction syntax: xv???ger??). */
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
/* 2 vector pairs are necessary for a double precision outer product
|
||||
/* 2 vector pairs are necessary for a double precision outer product
|
||||
instruction. */
|
||||
__vector_pair colA_1,
|
||||
__vector_pair colA_1,
|
||||
colA_2;
|
||||
|
||||
/* Prefetch C so that it stays in cache */
|
||||
@@ -123,17 +127,17 @@ void bli_dgemm_power10_mma_8x8
|
||||
|
||||
/* Load elements into vector registers */
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
vec_t *rb = (vec_t *) B0;
|
||||
vec_t *rb = (vec_t *) B0;
|
||||
|
||||
/* Each accumulator represents a matrix of size
|
||||
/* Each accumulator represents a matrix of size
|
||||
4 x ( 16 / (datatype size in bytes) ) (vector register size = 16B)
|
||||
|
||||
Thus in the case of double, the accumulate registers represent a 4x2
|
||||
Thus in the case of double, the accumulate registers represent a 4x2
|
||||
matrix. However, a vector register can hold at most 2 doubles. Thus, if
|
||||
we performed an outer product using 2 vector register, we can only get a
|
||||
we performed an outer product using 2 vector register, we can only get a
|
||||
2x2 matrix. Therefore, we must create a vector register pair in order
|
||||
to get the desired 4x2 matrix.
|
||||
|
||||
|
||||
*/
|
||||
D_ASSEMBLE_VEC_PAIR
|
||||
|
||||
@@ -158,7 +162,7 @@ void bli_dgemm_power10_mma_8x8
|
||||
D_AB_PRODUCT
|
||||
D_AB_PRODUCT
|
||||
}
|
||||
|
||||
|
||||
// edge loop
|
||||
for (int k = 0; k<k_left; k++)
|
||||
{
|
||||
@@ -189,4 +193,5 @@ void bli_dgemm_power10_mma_8x8
|
||||
SAVE_ACC_bz(dv4sf_t, &acc7, rs_c, 6+4*rs_c);
|
||||
}
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
@@ -55,7 +55,9 @@
|
||||
|
||||
void bli_i16gemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
int32_t* restrict alpha,
|
||||
short* restrict a,
|
||||
short* restrict b,
|
||||
@@ -66,8 +68,8 @@ void bli_i16gemm_power10_mma_8x16
|
||||
)
|
||||
{
|
||||
|
||||
uint64_t k_iter = (k0-1) / 4;
|
||||
uint64_t k_left = (k0-1) % 4;
|
||||
uint64_t k_iter = (k-1) / 4;
|
||||
uint64_t k_left = (k-1) % 4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
@@ -82,7 +84,7 @@ void bli_i16gemm_power10_mma_8x16
|
||||
iv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
|
||||
@@ -55,7 +55,9 @@
|
||||
|
||||
void bli_i16sgemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
int32_t* restrict alpha,
|
||||
short* restrict a,
|
||||
short* restrict b,
|
||||
@@ -66,8 +68,8 @@ void bli_i16sgemm_power10_mma_8x16
|
||||
)
|
||||
{
|
||||
|
||||
uint64_t k_iter = (k0-1) / 4;
|
||||
uint64_t k_left = (k0-1) % 4;
|
||||
uint64_t k_iter = (k-1) / 4;
|
||||
uint64_t k_left = (k-1) % 4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
@@ -82,7 +84,7 @@ void bli_i16sgemm_power10_mma_8x16
|
||||
iv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
|
||||
@@ -55,7 +55,9 @@
|
||||
|
||||
void bli_i4gemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
int32_t* restrict alpha,
|
||||
nibbles* restrict a,
|
||||
nibbles* restrict b,
|
||||
@@ -66,8 +68,8 @@ void bli_i4gemm_power10_mma_8x16
|
||||
)
|
||||
{
|
||||
|
||||
uint64_t k_iter = (k0-1) / 4;
|
||||
uint64_t k_left = (k0-1) % 4;
|
||||
uint64_t k_iter = (k-1) / 4;
|
||||
uint64_t k_left = (k-1) % 4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
@@ -82,11 +84,11 @@ void bli_i4gemm_power10_mma_8x16
|
||||
iv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
vec_t *rb = (vec_t *) B0;
|
||||
vec_t *rb = (vec_t *) B0;
|
||||
|
||||
__builtin_mma_xvi4ger8 (&acc0, ca[0], rb[0]);
|
||||
__builtin_mma_xvi4ger8 (&acc1, ca[0], rb[1]);
|
||||
@@ -96,23 +98,23 @@ void bli_i4gemm_power10_mma_8x16
|
||||
__builtin_mma_xvi4ger8 (&acc5, ca[1], rb[1]);
|
||||
__builtin_mma_xvi4ger8 (&acc6, ca[1], rb[2]);
|
||||
__builtin_mma_xvi4ger8 (&acc7, ca[1], rb[3]);
|
||||
|
||||
|
||||
I4_INCREMENT
|
||||
|
||||
// k loop (unrolled by 4)
|
||||
for (int k = 0; k<k_iter; k++)
|
||||
{
|
||||
I4_AB_PRODUCT
|
||||
I4_AB_PRODUCT
|
||||
I4_AB_PRODUCT
|
||||
I4_AB_PRODUCT
|
||||
}
|
||||
|
||||
// edge loop
|
||||
for (int k = 0; k<k_left; k++)
|
||||
{
|
||||
I4_AB_PRODUCT
|
||||
}
|
||||
for (int k = 0; k<k_iter; k++)
|
||||
{
|
||||
I4_AB_PRODUCT
|
||||
I4_AB_PRODUCT
|
||||
I4_AB_PRODUCT
|
||||
I4_AB_PRODUCT
|
||||
}
|
||||
|
||||
// edge loop
|
||||
for (int k = 0; k<k_left; k++)
|
||||
{
|
||||
I4_AB_PRODUCT
|
||||
}
|
||||
|
||||
// handle beta cases
|
||||
if (beta_ != 0.0)
|
||||
|
||||
@@ -55,7 +55,9 @@
|
||||
|
||||
void bli_i8gemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
int32_t* restrict alpha,
|
||||
int8_t* restrict a,
|
||||
int8_t* restrict b,
|
||||
@@ -65,8 +67,8 @@ void bli_i8gemm_power10_mma_8x16
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
uint64_t k_iter = (k0-1) / 4;
|
||||
uint64_t k_left = (k0-1) % 4;
|
||||
uint64_t k_iter = (k-1) / 4;
|
||||
uint64_t k_left = (k-1) % 4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
@@ -81,11 +83,11 @@ void bli_i8gemm_power10_mma_8x16
|
||||
iv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
vec_t *rb = (vec_t *) B0;
|
||||
vec_t *rb = (vec_t *) B0;
|
||||
|
||||
__builtin_mma_xvi8ger4 (&acc0, ca[0], rb[0]);
|
||||
__builtin_mma_xvi8ger4 (&acc1, ca[0], rb[1]);
|
||||
@@ -99,19 +101,19 @@ void bli_i8gemm_power10_mma_8x16
|
||||
I8_INCREMENT
|
||||
|
||||
// k loop (unrolled by 4)
|
||||
for (int k = 0; k<k_iter; k++)
|
||||
{
|
||||
I8_AB_PRODUCT
|
||||
I8_AB_PRODUCT
|
||||
I8_AB_PRODUCT
|
||||
I8_AB_PRODUCT
|
||||
}
|
||||
|
||||
// edge loop
|
||||
for (int k = 0; k<k_left; k++)
|
||||
{
|
||||
I8_AB_PRODUCT
|
||||
}
|
||||
for (int k = 0; k<k_iter; k++)
|
||||
{
|
||||
I8_AB_PRODUCT
|
||||
I8_AB_PRODUCT
|
||||
I8_AB_PRODUCT
|
||||
I8_AB_PRODUCT
|
||||
}
|
||||
|
||||
// edge loop
|
||||
for (int k = 0; k<k_left; k++)
|
||||
{
|
||||
I8_AB_PRODUCT
|
||||
}
|
||||
|
||||
// handle beta cases
|
||||
if (beta_ != 0.0)
|
||||
|
||||
@@ -42,21 +42,23 @@
|
||||
__builtin_mma_xvbf16ger2pp (&acc4, ca[1], rb[0]); \
|
||||
__builtin_mma_xvbf16ger2pp (&acc5, ca[1], rb[1]); \
|
||||
__builtin_mma_xvbf16ger2pp (&acc6, ca[1], rb[2]); \
|
||||
__builtin_mma_xvbf16ger2pp (&acc7, ca[1], rb[3]);
|
||||
__builtin_mma_xvbf16ger2pp (&acc7, ca[1], rb[3]);
|
||||
|
||||
#define B_INCREMENT \
|
||||
A0+=16; \
|
||||
B0+=32;
|
||||
|
||||
B0+=32;
|
||||
|
||||
#define B_AB_PRODUCT \
|
||||
LOAD_VECTORS \
|
||||
B_INCREMENT \
|
||||
B_ACCUMULATE
|
||||
B_ACCUMULATE
|
||||
|
||||
|
||||
void bli_sbgemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
bfloat16* restrict a,
|
||||
bfloat16* restrict b,
|
||||
@@ -67,8 +69,8 @@ void bli_sbgemm_power10_mma_8x16
|
||||
)
|
||||
{
|
||||
|
||||
uint64_t k_iter = (k0-1)/4;
|
||||
uint64_t k_left = (k0-1)%4;
|
||||
uint64_t k_iter = (k-1)/4;
|
||||
uint64_t k_left = (k-1)%4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
@@ -83,7 +85,7 @@ void bli_sbgemm_power10_mma_8x16
|
||||
fv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
|
||||
@@ -42,7 +42,7 @@
|
||||
__builtin_mma_xvf32gerpp (&acc4, ca[1], rb[0]); \
|
||||
__builtin_mma_xvf32gerpp (&acc5, ca[1], rb[1]); \
|
||||
__builtin_mma_xvf32gerpp (&acc6, ca[1], rb[2]); \
|
||||
__builtin_mma_xvf32gerpp (&acc7, ca[1], rb[3]);
|
||||
__builtin_mma_xvf32gerpp (&acc7, ca[1], rb[3]);
|
||||
|
||||
#define S_INCREMENT \
|
||||
A0+=8; \
|
||||
@@ -51,16 +51,18 @@
|
||||
#define S_AB_PRODUCT \
|
||||
LOAD_VECTORS \
|
||||
S_INCREMENT \
|
||||
S_ACCUMULATE
|
||||
S_ACCUMULATE
|
||||
|
||||
void bli_sgemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
float* restrict beta,
|
||||
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
float* restrict c, inc_t rs_c0, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
@@ -68,16 +70,18 @@ void bli_sgemm_power10_mma_8x16
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out)
|
||||
uint64_t k_iter = (k0-1) / 4;
|
||||
uint64_t k_left = (k0-1) % 4;
|
||||
|
||||
uint64_t k_iter = (k-1) / 4;
|
||||
uint64_t k_left = (k-1) % 4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT( s, 8, 16, true );
|
||||
|
||||
fv4sf_t result[4];
|
||||
fv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
float* restrict A0 = a;
|
||||
@@ -111,7 +115,7 @@ void bli_sgemm_power10_mma_8x16
|
||||
S_AB_PRODUCT
|
||||
S_AB_PRODUCT
|
||||
}
|
||||
|
||||
|
||||
// edge loop
|
||||
for (int k = 0; k<k_left; k++)
|
||||
{
|
||||
@@ -141,4 +145,6 @@ void bli_sgemm_power10_mma_8x16
|
||||
SAVE_ACC_bz(fv4sf_t, &acc6, rs_c, 8+4*rs_c);
|
||||
SAVE_ACC_bz(fv4sf_t, &acc7, rs_c, 12+4*rs_c);
|
||||
}
|
||||
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
}
|
||||
@@ -42,21 +42,23 @@
|
||||
__builtin_mma_xvf16ger2pp (&acc4, ca[1], rb[0]); \
|
||||
__builtin_mma_xvf16ger2pp (&acc5, ca[1], rb[1]); \
|
||||
__builtin_mma_xvf16ger2pp (&acc6, ca[1], rb[2]); \
|
||||
__builtin_mma_xvf16ger2pp (&acc7, ca[1], rb[3]);
|
||||
__builtin_mma_xvf16ger2pp (&acc7, ca[1], rb[3]);
|
||||
|
||||
#define H_INCREMENT \
|
||||
A0+=16; \
|
||||
B0+=32;
|
||||
|
||||
B0+=32;
|
||||
|
||||
#define H_AB_PRODUCT \
|
||||
LOAD_VECTORS \
|
||||
H_INCREMENT \
|
||||
H_ACCUMULATE
|
||||
H_ACCUMULATE
|
||||
|
||||
|
||||
void bli_shgemm_power10_mma_8x16
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float16* restrict a,
|
||||
float16* restrict b,
|
||||
@@ -67,8 +69,8 @@ void bli_shgemm_power10_mma_8x16
|
||||
)
|
||||
{
|
||||
|
||||
uint64_t k_iter = (k0-1)/4;
|
||||
uint64_t k_left = (k0-1)%4;
|
||||
uint64_t k_iter = (k-1)/4;
|
||||
uint64_t k_left = (k-1)%4;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
|
||||
@@ -83,7 +85,7 @@ void bli_shgemm_power10_mma_8x16
|
||||
fv4sf_t *rowC;
|
||||
|
||||
// accumulators that will hold the matrix product
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
__vector_quad acc0, acc1, acc2, acc3,
|
||||
acc4, acc5, acc6, acc7;
|
||||
|
||||
vec_t *ca = (vec_t *) A0;
|
||||
|
||||
@@ -50,32 +50,28 @@
|
||||
*/
|
||||
void bli_sgemm_power7_int_8x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
float* restrict beta,
|
||||
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
float* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k = k0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
#if 1 || defined(UTEST)
|
||||
const long MR = BLIS_DEFAULT_MR_S, NR = BLIS_DEFAULT_NR_S;
|
||||
const long LDA = MR, LDB = NR;
|
||||
long i, j, kk;
|
||||
float c00;
|
||||
|
||||
for (i=0; i < MR; i++) {
|
||||
for (j=0; j < NR; j++) {
|
||||
for (i=0; i < m; i++) {
|
||||
for (j=0; j < n; j++) {
|
||||
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||
for (kk=0; kk < k; kk++)
|
||||
for (kk=0; kk < k; kk++)
|
||||
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
|
||||
c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00;
|
||||
}
|
||||
@@ -96,24 +92,160 @@ void bli_sgemm_power7_int_8x4
|
||||
*/
|
||||
void bli_dgemm_power7_int_8x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
double* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k = k0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
if ( cs_c == 1 )
|
||||
{
|
||||
// Optimized code for case where C rows are contiguous (i.e. C is row-major)
|
||||
|
||||
vector double vzero = vec_splats( 0.0 );
|
||||
|
||||
vector double vc00_01 = vzero;
|
||||
vector double vc02_03 = vzero;
|
||||
vector double vc10_11 = vzero;
|
||||
vector double vc12_13 = vzero;
|
||||
vector double vc20_21 = vzero;
|
||||
vector double vc22_23 = vzero;
|
||||
vector double vc30_31 = vzero;
|
||||
vector double vc32_33 = vzero;
|
||||
vector double vc40_41 = vzero;
|
||||
vector double vc42_43 = vzero;
|
||||
vector double vc50_51 = vzero;
|
||||
vector double vc52_53 = vzero;
|
||||
vector double vc60_61 = vzero;
|
||||
vector double vc62_63 = vzero;
|
||||
vector double vc70_71 = vzero;
|
||||
vector double vc72_73 = vzero;
|
||||
|
||||
unsigned long long pa = (unsigned long long)a;
|
||||
unsigned long long pb = (unsigned long long)b;
|
||||
|
||||
#if 0
|
||||
unsigned long long d1 = 1*sizeof(double);
|
||||
unsigned long long d2 = 2*sizeof(double);
|
||||
unsigned long long d3 = 3*sizeof(double);
|
||||
unsigned long long d4 = 4*sizeof(double);
|
||||
unsigned long long d6 = 6*sizeof(double);
|
||||
#else
|
||||
// ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables
|
||||
register unsigned long long d1 __asm ("r21") = 1*sizeof(double);
|
||||
register unsigned long long d2 __asm ("r22") = 2*sizeof(double);
|
||||
register unsigned long long d3 __asm ("r23") = 3*sizeof(double);
|
||||
register unsigned long long d4 __asm ("r24") = 4*sizeof(double);
|
||||
register unsigned long long d5 __asm ("r25") = 5*sizeof(double);
|
||||
register unsigned long long d6 __asm ("r26") = 6*sizeof(double);
|
||||
register unsigned long long d7 __asm ("r27") = 7*sizeof(double);
|
||||
|
||||
__asm__ volatile (";" : "=r" (d1) : "r" (d1) );
|
||||
__asm__ volatile (";" : "=r" (d2) : "r" (d2) );
|
||||
__asm__ volatile (";" : "=r" (d3) : "r" (d3) );
|
||||
__asm__ volatile (";" : "=r" (d4) : "r" (d4) );
|
||||
__asm__ volatile (";" : "=r" (d5) : "r" (d5) );
|
||||
__asm__ volatile (";" : "=r" (d6) : "r" (d6) );
|
||||
__asm__ volatile (";" : "=r" (d7) : "r" (d7) );
|
||||
#endif
|
||||
|
||||
int kk;
|
||||
for (kk=k; kk > 0; kk--) {
|
||||
vector double va00 = vec_splats( *(double *)( pa+0 ) );
|
||||
vector double va10 = vec_splats( *(double *)( pa+d1 ) );
|
||||
vector double va20 = vec_splats( *(double *)( pa+d2 ) );
|
||||
vector double va30 = vec_splats( *(double *)( pa+d3 ) );
|
||||
vector double va40 = vec_splats( *(double *)( pa+d4 ) );
|
||||
vector double va50 = vec_splats( *(double *)( pa+d5 ) );
|
||||
vector double va60 = vec_splats( *(double *)( pa+d6 ) );
|
||||
vector double va70 = vec_splats( *(double *)( pa+d7 ) );
|
||||
pa += 8*sizeof(double);
|
||||
|
||||
vector double vb00_01 = *(vector double *)( pb+0 );
|
||||
vector double vb02_03 = *(vector double *)( pb+d2 );
|
||||
pb += 4*sizeof(double);
|
||||
|
||||
vc00_01 = vec_madd(va00, vb00_01, vc00_01);
|
||||
vc02_03 = vec_madd(va00, vb02_03, vc02_03);
|
||||
vc10_11 = vec_madd(va10, vb00_01, vc10_11);
|
||||
vc12_13 = vec_madd(va10, vb02_03, vc12_13);
|
||||
vc20_21 = vec_madd(va20, vb00_01, vc20_21);
|
||||
vc22_23 = vec_madd(va20, vb02_03, vc22_23);
|
||||
vc30_31 = vec_madd(va30, vb00_01, vc30_31);
|
||||
vc32_33 = vec_madd(va30, vb02_03, vc32_33);
|
||||
vc40_41 = vec_madd(va40, vb00_01, vc40_41);
|
||||
vc42_43 = vec_madd(va40, vb02_03, vc42_43);
|
||||
vc50_51 = vec_madd(va50, vb00_01, vc50_51);
|
||||
vc52_53 = vec_madd(va50, vb02_03, vc52_53);
|
||||
vc60_61 = vec_madd(va60, vb00_01, vc60_61);
|
||||
vc62_63 = vec_madd(va60, vb02_03, vc62_63);
|
||||
vc70_71 = vec_madd(va70, vb00_01, vc70_71);
|
||||
vc72_73 = vec_madd(va70, vb02_03, vc72_73);
|
||||
}
|
||||
|
||||
vector double valpha = vec_splats( *alpha );
|
||||
vector double vbeta = (vector double) { *beta, *beta };
|
||||
|
||||
vector double *pc = (vector double *)c;
|
||||
|
||||
vc00_01 = vec_mul(valpha, vc00_01);
|
||||
vc02_03 = vec_mul(valpha, vc02_03);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc00_01);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc02_03);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc10_11 = vec_mul(valpha, vc10_11);
|
||||
vc12_13 = vec_mul(valpha, vc12_13);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc10_11);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc12_13);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc20_21 = vec_mul(valpha, vc20_21);
|
||||
vc22_23 = vec_mul(valpha, vc22_23);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc20_21);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc22_23);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc30_31 = vec_mul(valpha, vc30_31);
|
||||
vc32_33 = vec_mul(valpha, vc32_33);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc30_31);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc32_33);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc40_41 = vec_mul(valpha, vc40_41);
|
||||
vc42_43 = vec_mul(valpha, vc42_43);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc40_41);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc42_43);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc50_51 = vec_mul(valpha, vc50_51);
|
||||
vc52_53 = vec_mul(valpha, vc52_53);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc50_51);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc52_53);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc60_61 = vec_mul(valpha, vc60_61);
|
||||
vc62_63 = vec_mul(valpha, vc62_63);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc60_61);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc62_63);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc70_71 = vec_mul(valpha, vc70_71);
|
||||
vc72_73 = vec_mul(valpha, vc72_73);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc70_71);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc72_73);
|
||||
pc += rs_c/2;
|
||||
}
|
||||
else
|
||||
{
|
||||
GEMM_UKR_SETUP_CT( d, 8, 4, false );
|
||||
|
||||
#if 1
|
||||
if (rs_c == 1) {
|
||||
// Optimized code for case where C columns are contiguous (column-major C)
|
||||
vector double vzero = vec_splats( 0.0 );
|
||||
|
||||
@@ -301,168 +433,8 @@ void bli_dgemm_power7_int_8x4
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc23_33);
|
||||
pc[2] = vec_madd( pc[2], vbeta, vc43_53);
|
||||
pc[3] = vec_madd( pc[3], vbeta, vc63_73);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#if 1
|
||||
if ( cs_c == 1 ) {
|
||||
// Optimized code for case where C rows are contiguous (i.e. C is row-major)
|
||||
|
||||
vector double vzero = vec_splats( 0.0 );
|
||||
|
||||
vector double vc00_01 = vzero;
|
||||
vector double vc02_03 = vzero;
|
||||
vector double vc10_11 = vzero;
|
||||
vector double vc12_13 = vzero;
|
||||
vector double vc20_21 = vzero;
|
||||
vector double vc22_23 = vzero;
|
||||
vector double vc30_31 = vzero;
|
||||
vector double vc32_33 = vzero;
|
||||
vector double vc40_41 = vzero;
|
||||
vector double vc42_43 = vzero;
|
||||
vector double vc50_51 = vzero;
|
||||
vector double vc52_53 = vzero;
|
||||
vector double vc60_61 = vzero;
|
||||
vector double vc62_63 = vzero;
|
||||
vector double vc70_71 = vzero;
|
||||
vector double vc72_73 = vzero;
|
||||
|
||||
unsigned long long pa = (unsigned long long)a;
|
||||
unsigned long long pb = (unsigned long long)b;
|
||||
|
||||
#if 0
|
||||
unsigned long long d1 = 1*sizeof(double);
|
||||
unsigned long long d2 = 2*sizeof(double);
|
||||
unsigned long long d3 = 3*sizeof(double);
|
||||
unsigned long long d4 = 4*sizeof(double);
|
||||
unsigned long long d6 = 6*sizeof(double);
|
||||
#else
|
||||
// ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables
|
||||
register unsigned long long d1 __asm ("r21") = 1*sizeof(double);
|
||||
register unsigned long long d2 __asm ("r22") = 2*sizeof(double);
|
||||
register unsigned long long d3 __asm ("r23") = 3*sizeof(double);
|
||||
register unsigned long long d4 __asm ("r24") = 4*sizeof(double);
|
||||
register unsigned long long d5 __asm ("r25") = 5*sizeof(double);
|
||||
register unsigned long long d6 __asm ("r26") = 6*sizeof(double);
|
||||
register unsigned long long d7 __asm ("r27") = 7*sizeof(double);
|
||||
|
||||
__asm__ volatile (";" : "=r" (d1) : "r" (d1) );
|
||||
__asm__ volatile (";" : "=r" (d2) : "r" (d2) );
|
||||
__asm__ volatile (";" : "=r" (d3) : "r" (d3) );
|
||||
__asm__ volatile (";" : "=r" (d4) : "r" (d4) );
|
||||
__asm__ volatile (";" : "=r" (d5) : "r" (d5) );
|
||||
__asm__ volatile (";" : "=r" (d6) : "r" (d6) );
|
||||
__asm__ volatile (";" : "=r" (d7) : "r" (d7) );
|
||||
#endif
|
||||
|
||||
int kk;
|
||||
for (kk=k; kk > 0; kk--) {
|
||||
vector double va00 = vec_splats( *(double *)( pa+0 ) );
|
||||
vector double va10 = vec_splats( *(double *)( pa+d1 ) );
|
||||
vector double va20 = vec_splats( *(double *)( pa+d2 ) );
|
||||
vector double va30 = vec_splats( *(double *)( pa+d3 ) );
|
||||
vector double va40 = vec_splats( *(double *)( pa+d4 ) );
|
||||
vector double va50 = vec_splats( *(double *)( pa+d5 ) );
|
||||
vector double va60 = vec_splats( *(double *)( pa+d6 ) );
|
||||
vector double va70 = vec_splats( *(double *)( pa+d7 ) );
|
||||
pa += 8*sizeof(double);
|
||||
|
||||
vector double vb00_01 = *(vector double *)( pb+0 );
|
||||
vector double vb02_03 = *(vector double *)( pb+d2 );
|
||||
pb += 4*sizeof(double);
|
||||
|
||||
vc00_01 = vec_madd(va00, vb00_01, vc00_01);
|
||||
vc02_03 = vec_madd(va00, vb02_03, vc02_03);
|
||||
vc10_11 = vec_madd(va10, vb00_01, vc10_11);
|
||||
vc12_13 = vec_madd(va10, vb02_03, vc12_13);
|
||||
vc20_21 = vec_madd(va20, vb00_01, vc20_21);
|
||||
vc22_23 = vec_madd(va20, vb02_03, vc22_23);
|
||||
vc30_31 = vec_madd(va30, vb00_01, vc30_31);
|
||||
vc32_33 = vec_madd(va30, vb02_03, vc32_33);
|
||||
vc40_41 = vec_madd(va40, vb00_01, vc40_41);
|
||||
vc42_43 = vec_madd(va40, vb02_03, vc42_43);
|
||||
vc50_51 = vec_madd(va50, vb00_01, vc50_51);
|
||||
vc52_53 = vec_madd(va50, vb02_03, vc52_53);
|
||||
vc60_61 = vec_madd(va60, vb00_01, vc60_61);
|
||||
vc62_63 = vec_madd(va60, vb02_03, vc62_63);
|
||||
vc70_71 = vec_madd(va70, vb00_01, vc70_71);
|
||||
vc72_73 = vec_madd(va70, vb02_03, vc72_73);
|
||||
}
|
||||
|
||||
vector double valpha = vec_splats( *alpha );
|
||||
vector double vbeta = (vector double) { *beta, *beta };
|
||||
|
||||
vector double *pc = (vector double *)c;
|
||||
|
||||
vc00_01 = vec_mul(valpha, vc00_01);
|
||||
vc02_03 = vec_mul(valpha, vc02_03);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc00_01);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc02_03);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc10_11 = vec_mul(valpha, vc10_11);
|
||||
vc12_13 = vec_mul(valpha, vc12_13);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc10_11);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc12_13);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc20_21 = vec_mul(valpha, vc20_21);
|
||||
vc22_23 = vec_mul(valpha, vc22_23);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc20_21);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc22_23);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc30_31 = vec_mul(valpha, vc30_31);
|
||||
vc32_33 = vec_mul(valpha, vc32_33);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc30_31);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc32_33);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc40_41 = vec_mul(valpha, vc40_41);
|
||||
vc42_43 = vec_mul(valpha, vc42_43);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc40_41);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc42_43);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc50_51 = vec_mul(valpha, vc50_51);
|
||||
vc52_53 = vec_mul(valpha, vc52_53);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc50_51);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc52_53);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc60_61 = vec_mul(valpha, vc60_61);
|
||||
vc62_63 = vec_mul(valpha, vc62_63);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc60_61);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc62_63);
|
||||
pc += rs_c/2;
|
||||
|
||||
vc70_71 = vec_mul(valpha, vc70_71);
|
||||
vc72_73 = vec_mul(valpha, vc72_73);
|
||||
pc[0] = vec_madd( pc[0], vbeta, vc70_71);
|
||||
pc[1] = vec_madd( pc[1], vbeta, vc72_73);
|
||||
pc += rs_c/2;
|
||||
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{ /* General case. Just do it right. */
|
||||
#if 1 || defined(UTEST)
|
||||
const long MR = BLIS_DEFAULT_MR_D, NR = BLIS_DEFAULT_NR_D;
|
||||
const long LDA = MR, LDB = NR;
|
||||
int i, j, kk;
|
||||
double c00;
|
||||
|
||||
for (i=0; i < MR; i++) {
|
||||
for (j=0; j < NR; j++) {
|
||||
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||
for (kk=0; kk < k; kk++)
|
||||
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
|
||||
c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00;
|
||||
}
|
||||
}
|
||||
#else
|
||||
//BLIS_DGEMM_UKERNEL_REF(k, alpha, a, b, beta, c, rs_c, cs_c, data);
|
||||
#endif
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,30 +449,26 @@ void bli_dgemm_power7_int_8x4
|
||||
*/
|
||||
void bli_cgemm_power7_int_8x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* restrict alpha,
|
||||
scomplex* restrict a,
|
||||
scomplex* restrict b,
|
||||
scomplex* restrict beta,
|
||||
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
scomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k = k0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
#if 1 || defined(UTEST)
|
||||
const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C;
|
||||
const long LDA = MR, LDB = NR;
|
||||
int i, j, kk;
|
||||
scomplex c00;
|
||||
|
||||
for (i=0; i < MR; i++) {
|
||||
for (j=0; j < NR; j++) {
|
||||
for (i=0; i < m; i++) {
|
||||
for (j=0; j < n; j++) {
|
||||
scomplex tmpc, tmpa, tmpb, tmp;
|
||||
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
|
||||
@@ -534,30 +502,26 @@ void bli_cgemm_power7_int_8x4
|
||||
*/
|
||||
void bli_zgemm_power7_int_8x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* restrict alpha,
|
||||
scomplex* restrict a,
|
||||
scomplex* restrict b,
|
||||
scomplex* restrict beta,
|
||||
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
||||
scomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||
auxinfo_t* restrict data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k = k0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
#if 1 || defined(UTEST)
|
||||
const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z;
|
||||
const long LDA = MR, LDB = NR;
|
||||
int i, j, kk;
|
||||
dcomplex c00;
|
||||
|
||||
for (i=0; i < MR; i++) {
|
||||
for (j=0; j < NR; j++) {
|
||||
for (i=0; i < m; i++) {
|
||||
for (j=0; j < n; j++) {
|
||||
dcomplex tmpc, tmpa, tmpb, tmp;
|
||||
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
|
||||
|
||||
@@ -43,6 +43,8 @@
|
||||
|
||||
void bli_sgemm_opt_8x4
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
@@ -55,6 +57,8 @@ void bli_sgemm_opt_8x4
|
||||
|
||||
void bli_dgemm_opt_8x4
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
@@ -67,6 +71,8 @@ void bli_dgemm_opt_8x4
|
||||
|
||||
void bli_cgemm_opt_8x4
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* restrict alpha,
|
||||
scomplex* restrict a,
|
||||
@@ -79,6 +85,8 @@ void bli_cgemm_opt_8x4
|
||||
|
||||
void bli_zgemm_opt_8x4
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
|
||||
@@ -37,7 +37,9 @@
|
||||
|
||||
void bli_dgemm_power9_asm_12x6
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
@@ -50,117 +52,91 @@ void bli_dgemm_power9_asm_12x6
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
|
||||
uint64_t k_iter = k0 / 16;
|
||||
uint64_t k_left = k0 % 16;
|
||||
uint64_t k_iter = k / 16;
|
||||
uint64_t k_left = k % 16;
|
||||
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
|
||||
GEMM_UKR_SETUP_CT( d, 12, 6, false );
|
||||
|
||||
__asm__ volatile
|
||||
(
|
||||
" \n\t"
|
||||
"ld %%r7, %2 \n\t" // load ptr of A
|
||||
"ld %%r8, %3 \n\t" // load ptr of B
|
||||
"ld %%r16, %6 \n\t" // load ptr of C
|
||||
" \n\t"
|
||||
"ld %%r28, %4 \n\t" // load ptr for alpha
|
||||
"ld %%r29, %5 \n\t" // load ptr for beta
|
||||
" \n\t"
|
||||
"ld %%r11, %0 \n\t" // load k_iter
|
||||
"ld %%r12, %1 \n\t" // load k_left
|
||||
" \n\t"
|
||||
"ld %%r10, %8 \n\t" // load cs_c
|
||||
"slwi %%r10, %%r10, 3 \n\t" // mul by size of elem
|
||||
" \n\t"
|
||||
"ld %%r9, %7 \n\t" // load rs_c
|
||||
"slwi %%r9, %%r9, 3 \n\t" // mul by size of elem
|
||||
" \n\t"
|
||||
"ld %%r26, 0(%%r29) \n\t" // load val of beta
|
||||
" \n\t"
|
||||
"lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha
|
||||
"lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta
|
||||
" \n\t"
|
||||
"add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C
|
||||
"add %%r18, %%r17, %%r10 \n\t" // col 2 of C
|
||||
"add %%r19, %%r18, %%r10 \n\t" // col 3 of C
|
||||
"add %%r20, %%r19, %%r10 \n\t" // col 4 of C
|
||||
"add %%r21, %%r20, %%r10 \n\t" // col 5 of C
|
||||
" \n\t"
|
||||
DZERO_OUT_VREG
|
||||
" \n\t"
|
||||
DPRELOAD
|
||||
" \n\t"
|
||||
"addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B
|
||||
"addi %%r7, %%r7, 96 \n\t"
|
||||
" \n\t"
|
||||
DPREFETCH
|
||||
" \n\t"
|
||||
"cmpwi %%r11, 0 \n\t" // if k_iter == 0,
|
||||
"beq DCONSIDERKLEFT \n\t" // then jmp to k_left
|
||||
"mtctr %%r11 \n\t" // else, do k_iter loop
|
||||
" \n\t"
|
||||
"DLOOPKITER: \n\t" // k_iter loop
|
||||
" \n\t"
|
||||
A_B_PRODUCT_16 // compute A*B
|
||||
" \n\t"
|
||||
"bdnz DLOOPKITER \n\t"
|
||||
" \n\t"
|
||||
"DCONSIDERKLEFT: \n\t"
|
||||
" \n\t"
|
||||
"cmpwi %%r12, 0 \n\t" // if k_left == 0,
|
||||
"beq DPOSTACCUM \n\t" // then jmp to post accum
|
||||
"mtctr %%r12 \n\t" // else, do k_left loop
|
||||
" \n\t"
|
||||
"DLOOPKLEFT: \n\t" // k_left loop
|
||||
" \n\t"
|
||||
A_B_PRODUCT_1
|
||||
" \n\t"
|
||||
"bdnz DLOOPKLEFT \n\t"
|
||||
" \n\t"
|
||||
"DPOSTACCUM: \n\t"
|
||||
" \n\t"
|
||||
DSCALE_ALPHA
|
||||
" \n\t"
|
||||
"cmpdi %%r26, 0 \n\t" // if beta == 0,
|
||||
"beq DBETAZERO \n\t" // then jmp to BZ
|
||||
" \n\t"
|
||||
"cmpwi %%r9, 8 \n\t" // if rs_c == 8
|
||||
"beq DCOLSTOREDBNZ \n\t" // then jmp to col store
|
||||
" \n\t"
|
||||
"DGENSTOREDBNZ: \n\t" // BNZ gen stored case
|
||||
" \n\t"
|
||||
DGEN_LOAD_OFS_C
|
||||
" \n\t"
|
||||
DGEN_SCALE_BETA
|
||||
" \n\t"
|
||||
"b DGENSTORED \n\t"
|
||||
" \n\t"
|
||||
"DCOLSTOREDBNZ: \n\t" // BNZ col stored case
|
||||
" \n\t"
|
||||
DCOL_SCALE_BETA
|
||||
" \n\t"
|
||||
"b DCOLSTORED \n\t"
|
||||
" \n\t"
|
||||
"DBETAZERO: \n\t" // BZ case
|
||||
" \n\t"
|
||||
"cmpwi %%r9, 8 \n\t" // if rs_c == 8,
|
||||
"beq DCOLSTORED \n\t" // C is col stored
|
||||
" \n\t"
|
||||
"DGENSTORED: \n\t" // BZ gen stored case
|
||||
" \n\t"
|
||||
DGEN_LOAD_OFS_C
|
||||
" \n\t"
|
||||
DGEN_STORE
|
||||
" \n\t"
|
||||
"b DDONE \n\t"
|
||||
" \n\t"
|
||||
"DCOLSTORED: \n\t" // BZ col stored case
|
||||
" \n\t"
|
||||
DCOL_STORE
|
||||
" \n\t"
|
||||
"DDONE: \n\t"
|
||||
" \n\t"
|
||||
: // output operands (none)
|
||||
(
|
||||
" \n\t"
|
||||
"ld %%r7, %2 \n\t" // load ptr of A
|
||||
"ld %%r8, %3 \n\t" // load ptr of B
|
||||
"ld %%r16, %6 \n\t" // load ptr of C
|
||||
" \n\t"
|
||||
"ld %%r28, %4 \n\t" // load ptr for alpha
|
||||
"ld %%r29, %5 \n\t" // load ptr for beta
|
||||
" \n\t"
|
||||
"ld %%r11, %0 \n\t" // load k_iter
|
||||
"ld %%r12, %1 \n\t" // load k_left
|
||||
" \n\t"
|
||||
"ld %%r10, %8 \n\t" // load cs_c
|
||||
"slwi %%r10, %%r10, 3 \n\t" // mul by size of elem
|
||||
" \n\t"
|
||||
"ld %%r9, %7 \n\t" // load rs_c
|
||||
"slwi %%r9, %%r9, 3 \n\t" // mul by size of elem
|
||||
" \n\t"
|
||||
"ld %%r26, 0(%%r29) \n\t" // load val of beta
|
||||
" \n\t"
|
||||
"lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha
|
||||
"lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta
|
||||
" \n\t"
|
||||
"add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C
|
||||
"add %%r18, %%r17, %%r10 \n\t" // col 2 of C
|
||||
"add %%r19, %%r18, %%r10 \n\t" // col 3 of C
|
||||
"add %%r20, %%r19, %%r10 \n\t" // col 4 of C
|
||||
"add %%r21, %%r20, %%r10 \n\t" // col 5 of C
|
||||
" \n\t"
|
||||
DZERO_OUT_VREG
|
||||
" \n\t"
|
||||
DPRELOAD
|
||||
" \n\t"
|
||||
"addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B
|
||||
"addi %%r7, %%r7, 96 \n\t"
|
||||
" \n\t"
|
||||
DPREFETCH
|
||||
" \n\t"
|
||||
"cmpwi %%r11, 0 \n\t" // if k_iter == 0,
|
||||
"beq DCONSIDERKLEFT \n\t" // then jmp to k_left
|
||||
"mtctr %%r11 \n\t" // else, do k_iter loop
|
||||
" \n\t"
|
||||
"DLOOPKITER: \n\t" // k_iter loop
|
||||
" \n\t"
|
||||
A_B_PRODUCT_16 // compute A*B
|
||||
" \n\t"
|
||||
"bdnz DLOOPKITER \n\t"
|
||||
" \n\t"
|
||||
"DCONSIDERKLEFT: \n\t"
|
||||
" \n\t"
|
||||
"cmpwi %%r12, 0 \n\t" // if k_left == 0,
|
||||
"beq DPOSTACCUM \n\t" // then jmp to post accum
|
||||
"mtctr %%r12 \n\t" // else, do k_left loop
|
||||
" \n\t"
|
||||
"DLOOPKLEFT: \n\t" // k_left loop
|
||||
" \n\t"
|
||||
A_B_PRODUCT_1
|
||||
" \n\t"
|
||||
"bdnz DLOOPKLEFT \n\t"
|
||||
" \n\t"
|
||||
"DPOSTACCUM: \n\t"
|
||||
" \n\t"
|
||||
DSCALE_ALPHA
|
||||
" \n\t"
|
||||
"cmpdi %%r26, 0 \n\t" // if beta == 0,
|
||||
"beq DBETAZERO \n\t" // then jmp to BZ
|
||||
" \n\t"
|
||||
DCOL_SCALE_BETA
|
||||
" \n\t"
|
||||
"DBETAZERO: \n\t" // BZ case
|
||||
" \n\t"
|
||||
DCOL_STORE
|
||||
" \n\t"
|
||||
"DDONE: \n\t"
|
||||
" \n\t"
|
||||
: // output operands (none)
|
||||
: // input operands
|
||||
"m" (k_iter), // 0
|
||||
"m" (k_left), // 1
|
||||
@@ -174,28 +150,30 @@ void bli_dgemm_power9_asm_12x6
|
||||
"m" (b_next), // 9
|
||||
"m" (a_next)*/ // 10
|
||||
: // register clobber list
|
||||
/* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */
|
||||
"r0", "r7", "r8", "r9",
|
||||
"r10", "r11", "r12", "r16", "r17", "r18", "r19",
|
||||
"r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29"
|
||||
/* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */
|
||||
"r0", "r7", "r8", "r9",
|
||||
"r10", "r11", "r12", "r16", "r17", "r18", "r19",
|
||||
"r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29"
|
||||
|
||||
#if XLC
|
||||
,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"
|
||||
, "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19"
|
||||
, "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29"
|
||||
, "f30" ,"f31"
|
||||
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9"
|
||||
, "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"
|
||||
, "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"
|
||||
, "v30", "v31"
|
||||
#else
|
||||
, "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9"
|
||||
, "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19"
|
||||
, "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29"
|
||||
, "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39"
|
||||
, "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49"
|
||||
, "vs50", "vs51", "vs52", "vs53"
|
||||
#endif
|
||||
#if XLC
|
||||
,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"
|
||||
, "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19"
|
||||
, "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29"
|
||||
, "f30" ,"f31"
|
||||
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9"
|
||||
, "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"
|
||||
, "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"
|
||||
, "v30", "v31"
|
||||
#else
|
||||
, "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9"
|
||||
, "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19"
|
||||
, "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29"
|
||||
, "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39"
|
||||
, "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49"
|
||||
, "vs50", "vs51", "vs52", "vs53"
|
||||
#endif
|
||||
|
||||
);
|
||||
);
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -32,14 +32,17 @@
|
||||
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <emmintrin.h>
|
||||
#include <immintrin.h>
|
||||
#include "blis.h"
|
||||
|
||||
|
||||
#if 0
|
||||
void bli_sgemm_sandybridge_int_8x8
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
@@ -52,11 +55,11 @@ void bli_sgemm_sandybridge_int_8x8
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
void bli_dgemm_sandybridge_int_8x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
@@ -66,19 +69,22 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
|
||||
//void* a_next = bli_auxinfo_next_a( data );
|
||||
void* b_next = bli_auxinfo_next_b( data );
|
||||
|
||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||
// different size than is expected by load instructions.
|
||||
uint64_t k_iter = k0 / 2;
|
||||
uint64_t k_left = k0 % 2;
|
||||
uint64_t k_iter = k / 2;
|
||||
uint64_t k_left = k % 2;
|
||||
uint64_t rs_c = rs_c0;
|
||||
uint64_t cs_c = cs_c0;
|
||||
uint64_t i;
|
||||
|
||||
double *c00, *c01, *c02, *c03;
|
||||
double *c40, *c41, *c42, *c43;
|
||||
GEMM_UKR_SETUP_CT( d, 8, 4, false );
|
||||
|
||||
double *c00, *c01, *c02, *c03;
|
||||
double *c40, *c41, *c42, *c43;
|
||||
|
||||
// Quad registers.
|
||||
__m256d va0_3, va4_7;
|
||||
@@ -87,23 +93,20 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
__m256d vb;
|
||||
__m256d vB0;
|
||||
|
||||
__m256d va0_3b_0, va4_7b_0;
|
||||
__m256d va0_3b_1, va4_7b_1;
|
||||
__m256d va0_3b_2, va4_7b_2;
|
||||
__m256d va0_3b_3, va4_7b_3;
|
||||
__m256d va0_3b_0, va4_7b_0;
|
||||
__m256d va0_3b_1, va4_7b_1;
|
||||
__m256d va0_3b_2, va4_7b_2;
|
||||
__m256d va0_3b_3, va4_7b_3;
|
||||
|
||||
__m256d va0_3b0, va4_7b0;
|
||||
__m256d va0_3b1, va4_7b1;
|
||||
__m256d va0_3b2, va4_7b2;
|
||||
__m256d va0_3b3, va4_7b3;
|
||||
__m256d va0_3b0, va4_7b0;
|
||||
__m256d va0_3b1, va4_7b1;
|
||||
__m256d va0_3b2, va4_7b2;
|
||||
__m256d va0_3b3, va4_7b3;
|
||||
|
||||
|
||||
__m256d valpha, vbeta, vtmp;
|
||||
__m256d valpha, vbeta, vtmp;
|
||||
__m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3;
|
||||
__m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3;
|
||||
|
||||
__m128d aa, bb;
|
||||
|
||||
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) );
|
||||
__asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) );
|
||||
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) );
|
||||
@@ -129,19 +132,19 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_3 = _mm256_setzero_pd();
|
||||
|
||||
// Load va0_3
|
||||
va0_3 = _mm256_load_pd( a );
|
||||
va0_3 = _mm256_load_pd( a );
|
||||
// Load va4_7
|
||||
va4_7 = _mm256_load_pd( a + 4 );
|
||||
va4_7 = _mm256_load_pd( a + 4 );
|
||||
|
||||
// Load vb (b0,b1,b2,b3)
|
||||
vb0 = _mm256_load_pd( b );
|
||||
// Load vb (b0,b1,b2,b3)
|
||||
vb0 = _mm256_load_pd( b );
|
||||
|
||||
for( i = 0; i < k_iter; ++i )
|
||||
{
|
||||
__asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) );
|
||||
|
||||
// Load va0_3 (Prefetch)
|
||||
vA0_3 = _mm256_load_pd( a + 8 );
|
||||
vA0_3 = _mm256_load_pd( a + 8 );
|
||||
|
||||
// Iteration 0.
|
||||
vtmp = _mm256_mul_pd( va0_3, vb0 );
|
||||
@@ -151,10 +154,10 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
||||
|
||||
// Load va4_7 (Prefetch)
|
||||
vA4_7 = _mm256_load_pd( a + 12 );
|
||||
vA4_7 = _mm256_load_pd( a + 12 );
|
||||
|
||||
// Shuffle vb (b1,b0,b3,b2)
|
||||
vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 );
|
||||
vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb1 );
|
||||
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
||||
@@ -163,10 +166,10 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
||||
|
||||
// Permute vb (b3,b2,b1,b0)
|
||||
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
||||
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
||||
|
||||
// Load vb (b0,b1,b2,b3) (Prefetch)
|
||||
vB0 = _mm256_load_pd( b + 4 );
|
||||
vB0 = _mm256_load_pd( b + 4 );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb2 );
|
||||
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
||||
@@ -175,7 +178,7 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
||||
|
||||
// Shuffle vb (b3,b2,b1,b0)
|
||||
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
||||
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb3 );
|
||||
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
||||
@@ -186,14 +189,14 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
// Iteration 1.
|
||||
|
||||
__asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) );
|
||||
|
||||
|
||||
// Load va0_3 (Next iteration)
|
||||
va0_3 = _mm256_load_pd( a + 16 );
|
||||
va0_3 = _mm256_load_pd( a + 16 );
|
||||
|
||||
vtmp = _mm256_mul_pd( vA0_3, vB0 );
|
||||
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
|
||||
|
||||
vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 );
|
||||
vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 );
|
||||
|
||||
vtmp = _mm256_mul_pd( vA4_7, vB0 );
|
||||
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
||||
@@ -202,9 +205,9 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
||||
|
||||
// Load va4_7 (Next iteration)
|
||||
va4_7 = _mm256_load_pd( a + 20 );
|
||||
va4_7 = _mm256_load_pd( a + 20 );
|
||||
|
||||
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
||||
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
||||
|
||||
vtmp = _mm256_mul_pd( vA4_7, vb1 );
|
||||
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
||||
@@ -212,13 +215,13 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vtmp = _mm256_mul_pd( vA0_3, vb2 );
|
||||
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
||||
|
||||
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
||||
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
||||
|
||||
vtmp = _mm256_mul_pd( vA4_7, vb2 );
|
||||
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
||||
|
||||
// Load vb0(Next iteration)
|
||||
vb0 = _mm256_load_pd( b + 8 );
|
||||
vb0 = _mm256_load_pd( b + 8 );
|
||||
|
||||
vtmp = _mm256_mul_pd( vA0_3, vb3 );
|
||||
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
||||
@@ -236,12 +239,12 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
// Iteration 0.
|
||||
|
||||
// Load va0_3
|
||||
va0_3 = _mm256_load_pd( a );
|
||||
va0_3 = _mm256_load_pd( a );
|
||||
// Load va4_7
|
||||
va4_7 = _mm256_load_pd( a + 4 );
|
||||
va4_7 = _mm256_load_pd( a + 4 );
|
||||
|
||||
// Load vb (b0,b1,b2,b3)
|
||||
vb = _mm256_load_pd( b );
|
||||
// Load vb (b0,b1,b2,b3)
|
||||
vb = _mm256_load_pd( b );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
|
||||
@@ -250,7 +253,7 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
||||
|
||||
// Shuffle vb (b1,b0,b3,b2)
|
||||
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
||||
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
||||
@@ -259,7 +262,7 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
||||
|
||||
// Permute vb (b3,b2,b1,b0)
|
||||
vb = _mm256_permute2f128_pd( vb, vb, 0x1 );
|
||||
vb = _mm256_permute2f128_pd( vb, vb, 0x1 );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
||||
@@ -268,7 +271,7 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
||||
|
||||
// Shuffle vb (b3,b2,b1,b0)
|
||||
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
||||
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
||||
|
||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
||||
@@ -309,12 +312,72 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 );
|
||||
va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 );
|
||||
|
||||
if( rs_c == 1 )
|
||||
__m128d vzero = _mm_setzero_pd( );
|
||||
|
||||
if( _mm_comieq_sd( _mm256_castpd256_pd128(vbeta), vzero ) )
|
||||
{
|
||||
// Calculate address
|
||||
c00 = ( c + 0*rs_c + 0*cs_c );
|
||||
c00 = ( c + 0 + 0*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c00, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c40 = ( c + 4 + 0*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c40, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c01 = ( c + 0 + 1*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c01, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c41 = ( c + 4 + 1*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c41, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c02 = ( c + 0 + 2*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c02, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c42 = ( c + 4 + 2*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c42, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c03 = ( c + 0 + 3*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c03, vtmp );
|
||||
|
||||
// Calculate address
|
||||
c43 = ( c + 4 + 3*cs_c );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c43, vtmp );
|
||||
}
|
||||
else
|
||||
{
|
||||
// Calculate address
|
||||
c00 = ( c + 0 + 0*cs_c );
|
||||
// Load
|
||||
//vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c );
|
||||
//vc0_3_0 = _mm256_load_pd( c + 0 + 0*cs_c );
|
||||
vc0_3_0 = _mm256_load_pd( c00 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
||||
@@ -324,11 +387,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c00, vc0_3_0 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c40 = ( c + 4*rs_c + 0*cs_c );
|
||||
c40 = ( c + 4 + 0*cs_c );
|
||||
// Load
|
||||
//vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c );
|
||||
//vc4_7_0 = _mm256_load_pd( c + 4 + 0*cs_c );
|
||||
vc4_7_0 = _mm256_load_pd( c40 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
||||
@@ -338,11 +401,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c40, vc4_7_0 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c01 = ( c + 0*rs_c + 1*cs_c );
|
||||
c01 = ( c + 0 + 1*cs_c );
|
||||
// Load
|
||||
//vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c );
|
||||
//vc0_3_1 = _mm256_load_pd( c + 0 + 1*cs_c );
|
||||
vc0_3_1 = _mm256_load_pd( c01 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
||||
@@ -352,11 +415,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c01, vc0_3_1 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c41 = ( c + 4*rs_c + 1*cs_c );
|
||||
c41 = ( c + 4 + 1*cs_c );
|
||||
// Load
|
||||
//vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c );
|
||||
//vc4_7_1 = _mm256_load_pd( c + 4 + 1*cs_c );
|
||||
vc4_7_1 = _mm256_load_pd( c41 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
||||
@@ -366,11 +429,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c41, vc4_7_1 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c02 = ( c + 0*rs_c + 2*cs_c );
|
||||
c02 = ( c + 0 + 2*cs_c );
|
||||
// Load
|
||||
//vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c );
|
||||
//vc0_3_2 = _mm256_load_pd( c + 0 + 2*cs_c );
|
||||
vc0_3_2 = _mm256_load_pd( c02 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
||||
@@ -380,11 +443,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c02, vc0_3_2 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c42 = ( c + 4*rs_c + 2*cs_c );
|
||||
c42 = ( c + 4 + 2*cs_c );
|
||||
// Load
|
||||
//vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c );
|
||||
//vc4_7_2 = _mm256_load_pd( c + 4 + 2*cs_c );
|
||||
vc4_7_2 = _mm256_load_pd( c42 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
||||
@@ -394,11 +457,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c42, vc4_7_2 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c03 = ( c + 0*rs_c + 3*cs_c );
|
||||
c03 = ( c + 0 + 3*cs_c );
|
||||
// Load
|
||||
//vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c );
|
||||
//vc0_3_3 = _mm256_load_pd( c + 0 + 3*cs_c );
|
||||
vc0_3_3 = _mm256_load_pd( c03 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
||||
@@ -408,11 +471,11 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c03, vc0_3_3 );
|
||||
|
||||
|
||||
// Calculate address
|
||||
c43 = ( c + 4*rs_c + 3*cs_c );
|
||||
c43 = ( c + 4 + 3*cs_c );
|
||||
// Load
|
||||
//vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c );
|
||||
//vc4_7_3 = _mm256_load_pd( c + 4 + 3*cs_c );
|
||||
vc4_7_3 = _mm256_load_pd( c43 );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
||||
@@ -422,211 +485,9 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
|
||||
// Store back to memory
|
||||
_mm256_store_pd( c43, vc4_7_3 );
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
// Calculate address
|
||||
c00 = ( c + 0*rs_c + 0*cs_c );
|
||||
// Load
|
||||
//vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c );
|
||||
vc0_3_0 = _mm256_set_pd( *(c + 3*rs_c + 0*cs_c ),
|
||||
*(c + 2*rs_c + 0*cs_c ),
|
||||
*(c + 1*rs_c + 0*cs_c ),
|
||||
*(c + 0*rs_c + 0*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
||||
// Scale by beta
|
||||
vc0_3_0 = _mm256_mul_pd( vbeta, vc0_3_0 );
|
||||
// Add gemm result
|
||||
vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c00, vc0_3_0 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc0_3_0, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc0_3_0, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 0*rs_c + 0*cs_c, aa );
|
||||
_mm_storeh_pd( c + 1*rs_c + 0*cs_c, aa );
|
||||
_mm_storel_pd( c + 2*rs_c + 0*cs_c, bb );
|
||||
_mm_storeh_pd( c + 3*rs_c + 0*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c40 = ( c + 4*rs_c + 0*cs_c );
|
||||
// Load
|
||||
//vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c );
|
||||
vc4_7_0 = _mm256_set_pd( *(c + 7*rs_c + 0*cs_c ),
|
||||
*(c + 6*rs_c + 0*cs_c ),
|
||||
*(c + 5*rs_c + 0*cs_c ),
|
||||
*(c + 4*rs_c + 0*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
||||
// Scale by beta
|
||||
vc4_7_0 = _mm256_mul_pd( vbeta, vc4_7_0 );
|
||||
// Add gemm result
|
||||
vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c40, vc4_7_0 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc4_7_0, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc4_7_0, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 4*rs_c + 0*cs_c, aa );
|
||||
_mm_storeh_pd( c + 5*rs_c + 0*cs_c, aa );
|
||||
_mm_storel_pd( c + 6*rs_c + 0*cs_c, bb );
|
||||
_mm_storeh_pd( c + 7*rs_c + 0*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c01 = ( c + 0*rs_c + 1*cs_c );
|
||||
// Load
|
||||
//vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c );
|
||||
vc0_3_1 = _mm256_set_pd( *(c + 3*rs_c + 1*cs_c ),
|
||||
*(c + 2*rs_c + 1*cs_c ),
|
||||
*(c + 1*rs_c + 1*cs_c ),
|
||||
*(c + 0*rs_c + 1*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
||||
// Scale by beta
|
||||
vc0_3_1 = _mm256_mul_pd( vbeta, vc0_3_1 );
|
||||
// Add gemm result
|
||||
vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c01, vc0_3_1 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc0_3_1, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc0_3_1, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 0*rs_c + 1*cs_c, aa );
|
||||
_mm_storeh_pd( c + 1*rs_c + 1*cs_c, aa );
|
||||
_mm_storel_pd( c + 2*rs_c + 1*cs_c, bb );
|
||||
_mm_storeh_pd( c + 3*rs_c + 1*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c41 = ( c + 4*rs_c + 1*cs_c );
|
||||
// Load
|
||||
//vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c );
|
||||
vc4_7_1 = _mm256_set_pd( *(c + 7*rs_c + 1*cs_c ),
|
||||
*(c + 6*rs_c + 1*cs_c ),
|
||||
*(c + 5*rs_c + 1*cs_c ),
|
||||
*(c + 4*rs_c + 1*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
||||
// Scale by beta
|
||||
vc4_7_1 = _mm256_mul_pd( vbeta, vc4_7_1 );
|
||||
// Add gemm result
|
||||
vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c41, vc4_7_1 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc4_7_1, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc4_7_1, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 4*rs_c + 1*cs_c, aa );
|
||||
_mm_storeh_pd( c + 5*rs_c + 1*cs_c, aa );
|
||||
_mm_storel_pd( c + 6*rs_c + 1*cs_c, bb );
|
||||
_mm_storeh_pd( c + 7*rs_c + 1*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c02 = ( c + 0*rs_c + 2*cs_c );
|
||||
// Load
|
||||
//vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c );
|
||||
vc0_3_2 = _mm256_set_pd( *(c + 3*rs_c + 2*cs_c ),
|
||||
*(c + 2*rs_c + 2*cs_c ),
|
||||
*(c + 1*rs_c + 2*cs_c ),
|
||||
*(c + 0*rs_c + 2*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
||||
// Scale by beta
|
||||
vc0_3_2 = _mm256_mul_pd( vbeta, vc0_3_2 );
|
||||
// Add gemm result
|
||||
vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c02, vc0_3_2 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc0_3_2, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc0_3_2, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 0*rs_c + 2*cs_c, aa );
|
||||
_mm_storeh_pd( c + 1*rs_c + 2*cs_c, aa );
|
||||
_mm_storel_pd( c + 2*rs_c + 2*cs_c, bb );
|
||||
_mm_storeh_pd( c + 3*rs_c + 2*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c42 = ( c + 4*rs_c + 2*cs_c );
|
||||
// Load
|
||||
//vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c );
|
||||
vc4_7_2 = _mm256_set_pd( *(c + 7*rs_c + 2*cs_c ),
|
||||
*(c + 6*rs_c + 2*cs_c ),
|
||||
*(c + 5*rs_c + 2*cs_c ),
|
||||
*(c + 4*rs_c + 2*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
||||
// Scale by beta
|
||||
vc4_7_2 = _mm256_mul_pd( vbeta, vc4_7_2 );
|
||||
// Add gemm result
|
||||
vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c42, vc4_7_2 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc4_7_2, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc4_7_2, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 4*rs_c + 2*cs_c, aa );
|
||||
_mm_storeh_pd( c + 5*rs_c + 2*cs_c, aa );
|
||||
_mm_storel_pd( c + 6*rs_c + 2*cs_c, bb );
|
||||
_mm_storeh_pd( c + 7*rs_c + 2*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c03 = ( c + 0*rs_c + 3*cs_c );
|
||||
// Load
|
||||
//vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c );
|
||||
vc0_3_3 = _mm256_set_pd( *(c + 3*rs_c + 3*cs_c ),
|
||||
*(c + 2*rs_c + 3*cs_c ),
|
||||
*(c + 1*rs_c + 3*cs_c ),
|
||||
*(c + 0*rs_c + 3*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
||||
// Scale by beta
|
||||
vc0_3_3 = _mm256_mul_pd( vbeta, vc0_3_3 );
|
||||
// Add gemm result
|
||||
vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c03, vc0_3_3 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc0_3_3, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc0_3_3, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 0*rs_c + 3*cs_c, aa );
|
||||
_mm_storeh_pd( c + 1*rs_c + 3*cs_c, aa );
|
||||
_mm_storel_pd( c + 2*rs_c + 3*cs_c, bb );
|
||||
_mm_storeh_pd( c + 3*rs_c + 3*cs_c, bb );
|
||||
|
||||
// Calculate address
|
||||
c43 = ( c + 4*rs_c + 3*cs_c );
|
||||
// Load
|
||||
//vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c );
|
||||
vc4_7_3 = _mm256_set_pd( *(c + 7*rs_c + 3*cs_c ),
|
||||
*(c + 6*rs_c + 3*cs_c ),
|
||||
*(c + 5*rs_c + 3*cs_c ),
|
||||
*(c + 4*rs_c + 3*cs_c ) );
|
||||
// Scale by alpha
|
||||
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
||||
// Scale by beta
|
||||
vc4_7_3 = _mm256_mul_pd( vbeta, vc4_7_3 );
|
||||
// Add gemm result
|
||||
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
|
||||
// Store back to memory
|
||||
//_mm256_store_pd( c43, vc4_7_3 );
|
||||
|
||||
aa = _mm256_extractf128_pd( vc4_7_3, 0 ) ;
|
||||
bb = _mm256_extractf128_pd( vc4_7_3, 1 ) ;
|
||||
|
||||
_mm_storel_pd( c + 4*rs_c + 3*cs_c, aa );
|
||||
_mm_storeh_pd( c + 5*rs_c + 3*cs_c, aa );
|
||||
_mm_storel_pd( c + 6*rs_c + 3*cs_c, bb );
|
||||
_mm_storeh_pd( c + 7*rs_c + 3*cs_c, bb );
|
||||
}
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
|
||||
@@ -634,7 +495,9 @@ void bli_dgemm_sandybridge_int_8x4
|
||||
#if 0
|
||||
void bli_cgemm_sandybridge_int_8x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
scomplex* restrict alpha,
|
||||
scomplex* restrict a,
|
||||
scomplex* restrict b,
|
||||
@@ -652,7 +515,9 @@ void bli_cgemm_sandybridge_int_8x4
|
||||
#if 0
|
||||
void bli_zgemm_sandybridge_int_4x4
|
||||
(
|
||||
dim_t k0,
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k,
|
||||
dcomplex* restrict alpha,
|
||||
dcomplex* restrict a,
|
||||
dcomplex* restrict b,
|
||||
|
||||
@@ -287,24 +287,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) =
|
||||
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
||||
|
||||
|
||||
void bli_dgemm_skx_asm_16x12_l2(
|
||||
dim_t k_,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||
auxinfo_t* data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
void bli_dgemm_skx_asm_16x12_l2
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k_,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||
auxinfo_t* data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
(void)data;
|
||||
(void)cntx;
|
||||
|
||||
const int64_t* offsetPtr = &offsets[0];
|
||||
const int64_t k = k_;
|
||||
const int64_t rs_c = rs_c_;
|
||||
const int64_t cs_c = cs_c_;
|
||||
int64_t k = k_;
|
||||
int64_t rs_c = rs_c_;
|
||||
int64_t cs_c = cs_c_;
|
||||
|
||||
GEMM_UKR_SETUP_CT( d, 16, 12, false );
|
||||
|
||||
BEGIN_ASM()
|
||||
|
||||
@@ -464,62 +468,26 @@ void bli_dgemm_skx_asm_16x12_l2(
|
||||
|
||||
MOV(RAX, VAR(cs_c))
|
||||
LEA(RAX, MEM(,RAX,8))
|
||||
MOV(RBX, VAR(rs_c))
|
||||
LEA(RBX, MEM(,RBX,8))
|
||||
|
||||
// Check if C is column stride. If not, jump to the slow scattered update
|
||||
CMP(RBX, IMM(1))
|
||||
JNE(SCATTEREDUPDATE)
|
||||
VCOMISD(XMM(1), XMM(7))
|
||||
JE(COLSTORBZ)
|
||||
|
||||
VCOMISD(XMM(1), XMM(7))
|
||||
JE(COLSTORBZ)
|
||||
|
||||
UPDATE_C( 8, 9,10,11)
|
||||
UPDATE_C(12,13,14,15)
|
||||
UPDATE_C(16,17,18,19)
|
||||
UPDATE_C(20,21,22,23)
|
||||
UPDATE_C(24,25,26,27)
|
||||
UPDATE_C(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(COLSTORBZ)
|
||||
|
||||
UPDATE_C_BZ( 8, 9,10,11)
|
||||
UPDATE_C_BZ(12,13,14,15)
|
||||
UPDATE_C_BZ(16,17,18,19)
|
||||
UPDATE_C_BZ(20,21,22,23)
|
||||
UPDATE_C_BZ(24,25,26,27)
|
||||
UPDATE_C_BZ(28,29,30,31)
|
||||
UPDATE_C( 8, 9,10,11)
|
||||
UPDATE_C(12,13,14,15)
|
||||
UPDATE_C(16,17,18,19)
|
||||
UPDATE_C(20,21,22,23)
|
||||
UPDATE_C(24,25,26,27)
|
||||
UPDATE_C(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(SCATTEREDUPDATE)
|
||||
LABEL(COLSTORBZ)
|
||||
|
||||
MOV(RDI, VAR(offsetPtr))
|
||||
VMOVDQA64(ZMM(2), MEM(RDI,0*64))
|
||||
VMOVDQA64(ZMM(3), MEM(RDI,1*64))
|
||||
VPBROADCASTQ(ZMM(6), RBX)
|
||||
VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
|
||||
VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
|
||||
|
||||
VCOMISD(XMM(1), XMM(7))
|
||||
JE(SCATTERBZ)
|
||||
|
||||
UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
|
||||
UPDATE_C_ROW_SCATTERED(12,13,14,15)
|
||||
UPDATE_C_ROW_SCATTERED(16,17,18,19)
|
||||
UPDATE_C_ROW_SCATTERED(20,21,22,23)
|
||||
UPDATE_C_ROW_SCATTERED(24,25,26,27)
|
||||
UPDATE_C_ROW_SCATTERED(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(SCATTERBZ)
|
||||
|
||||
UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)
|
||||
UPDATE_C_BZ( 8, 9,10,11)
|
||||
UPDATE_C_BZ(12,13,14,15)
|
||||
UPDATE_C_BZ(16,17,18,19)
|
||||
UPDATE_C_BZ(20,21,22,23)
|
||||
UPDATE_C_BZ(24,25,26,27)
|
||||
UPDATE_C_BZ(28,29,30,31)
|
||||
|
||||
LABEL(END)
|
||||
|
||||
@@ -535,8 +503,7 @@ void bli_dgemm_skx_asm_16x12_l2(
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
[rs_c] "m" (rs_c),
|
||||
[cs_c] "m" (cs_c),
|
||||
[offsetPtr] "m" (offsetPtr)
|
||||
[cs_c] "m" (cs_c)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
||||
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
||||
@@ -545,4 +512,6 @@ void bli_dgemm_skx_asm_16x12_l2(
|
||||
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
||||
"zmm30", "zmm31", "memory"
|
||||
)
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
@@ -153,24 +153,28 @@
|
||||
static int64_t offsets[16] __attribute__((aligned(64))) =
|
||||
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
||||
|
||||
void bli_dgemm_skx_asm_16x14(
|
||||
dim_t k_,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||
auxinfo_t* data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
void bli_dgemm_skx_asm_16x14
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k_,
|
||||
double* restrict alpha,
|
||||
double* restrict a,
|
||||
double* restrict b,
|
||||
double* restrict beta,
|
||||
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||
auxinfo_t* data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
(void)data;
|
||||
(void)cntx;
|
||||
|
||||
const int64_t* offsetPtr = &offsets[0];
|
||||
const int64_t k = k_;
|
||||
const int64_t rs_c = rs_c_*8;
|
||||
const int64_t cs_c = cs_c_*8;
|
||||
int64_t k = k_;
|
||||
int64_t rs_c = rs_c_;
|
||||
int64_t cs_c = cs_c_;
|
||||
|
||||
GEMM_UKR_SETUP_CT( d, 16, 14, false );
|
||||
|
||||
BEGIN_ASM()
|
||||
|
||||
@@ -220,6 +224,8 @@ void bli_dgemm_skx_asm_16x14(
|
||||
|
||||
MOV(R12, VAR(rs_c))
|
||||
MOV(R10, VAR(cs_c))
|
||||
LEA(R12, MEM(,R12,8))
|
||||
LEA(R10, MEM(,R10,8))
|
||||
|
||||
MOV(RDI, RSI)
|
||||
AND(RSI, IMM(3))
|
||||
@@ -320,119 +326,41 @@ void bli_dgemm_skx_asm_16x14(
|
||||
MOV(RAX, R12)
|
||||
MOV(RBX, R10)
|
||||
|
||||
// Check if C is column stride.
|
||||
CMP(RAX, IMM(8))
|
||||
JNE(SCATTEREDUPDATE)
|
||||
VCOMISD(XMM(1), XMM(2))
|
||||
JE(COLSTORBZ)
|
||||
|
||||
VCOMISD(XMM(1), XMM(2))
|
||||
JE(COLSTORBZ)
|
||||
|
||||
UPDATE_C( 4, 5)
|
||||
UPDATE_C( 6, 7)
|
||||
UPDATE_C( 8, 9)
|
||||
UPDATE_C(10,11)
|
||||
UPDATE_C(12,13)
|
||||
UPDATE_C(14,15)
|
||||
UPDATE_C(16,17)
|
||||
UPDATE_C(18,19)
|
||||
UPDATE_C(20,21)
|
||||
UPDATE_C(22,23)
|
||||
UPDATE_C(24,25)
|
||||
UPDATE_C(26,27)
|
||||
UPDATE_C(28,29)
|
||||
UPDATE_C(30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(COLSTORBZ)
|
||||
|
||||
UPDATE_C_BZ( 4, 5)
|
||||
UPDATE_C_BZ( 6, 7)
|
||||
UPDATE_C_BZ( 8, 9)
|
||||
UPDATE_C_BZ(10,11)
|
||||
UPDATE_C_BZ(12,13)
|
||||
UPDATE_C_BZ(14,15)
|
||||
UPDATE_C_BZ(16,17)
|
||||
UPDATE_C_BZ(18,19)
|
||||
UPDATE_C_BZ(20,21)
|
||||
UPDATE_C_BZ(22,23)
|
||||
UPDATE_C_BZ(24,25)
|
||||
UPDATE_C_BZ(26,27)
|
||||
UPDATE_C_BZ(28,29)
|
||||
UPDATE_C_BZ(30,31)
|
||||
UPDATE_C( 4, 5)
|
||||
UPDATE_C( 6, 7)
|
||||
UPDATE_C( 8, 9)
|
||||
UPDATE_C(10,11)
|
||||
UPDATE_C(12,13)
|
||||
UPDATE_C(14,15)
|
||||
UPDATE_C(16,17)
|
||||
UPDATE_C(18,19)
|
||||
UPDATE_C(20,21)
|
||||
UPDATE_C(22,23)
|
||||
UPDATE_C(24,25)
|
||||
UPDATE_C(26,27)
|
||||
UPDATE_C(28,29)
|
||||
UPDATE_C(30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(SCATTEREDUPDATE)
|
||||
LABEL(COLSTORBZ)
|
||||
|
||||
VMULPD(ZMM( 4), ZMM( 4), ZMM(0))
|
||||
VMULPD(ZMM( 5), ZMM( 5), ZMM(0))
|
||||
VMULPD(ZMM( 6), ZMM( 6), ZMM(0))
|
||||
VMULPD(ZMM( 7), ZMM( 7), ZMM(0))
|
||||
VMULPD(ZMM( 8), ZMM( 8), ZMM(0))
|
||||
VMULPD(ZMM( 9), ZMM( 9), ZMM(0))
|
||||
VMULPD(ZMM(10), ZMM(10), ZMM(0))
|
||||
VMULPD(ZMM(11), ZMM(11), ZMM(0))
|
||||
VMULPD(ZMM(12), ZMM(12), ZMM(0))
|
||||
VMULPD(ZMM(13), ZMM(13), ZMM(0))
|
||||
VMULPD(ZMM(14), ZMM(14), ZMM(0))
|
||||
VMULPD(ZMM(15), ZMM(15), ZMM(0))
|
||||
VMULPD(ZMM(16), ZMM(16), ZMM(0))
|
||||
VMULPD(ZMM(17), ZMM(17), ZMM(0))
|
||||
VMULPD(ZMM(18), ZMM(18), ZMM(0))
|
||||
VMULPD(ZMM(19), ZMM(19), ZMM(0))
|
||||
VMULPD(ZMM(20), ZMM(20), ZMM(0))
|
||||
VMULPD(ZMM(21), ZMM(21), ZMM(0))
|
||||
VMULPD(ZMM(22), ZMM(22), ZMM(0))
|
||||
VMULPD(ZMM(23), ZMM(23), ZMM(0))
|
||||
VMULPD(ZMM(24), ZMM(24), ZMM(0))
|
||||
VMULPD(ZMM(25), ZMM(25), ZMM(0))
|
||||
VMULPD(ZMM(26), ZMM(26), ZMM(0))
|
||||
VMULPD(ZMM(27), ZMM(27), ZMM(0))
|
||||
VMULPD(ZMM(28), ZMM(28), ZMM(0))
|
||||
VMULPD(ZMM(29), ZMM(29), ZMM(0))
|
||||
VMULPD(ZMM(30), ZMM(30), ZMM(0))
|
||||
VMULPD(ZMM(31), ZMM(31), ZMM(0))
|
||||
|
||||
VCOMISD(XMM(1), XMM(2))
|
||||
|
||||
MOV(RDI, VAR(offsetPtr))
|
||||
VPBROADCASTQ(ZMM(0), RAX)
|
||||
VPMULLQ(ZMM(2), ZMM(0), MEM(RDI))
|
||||
VPMULLQ(ZMM(3), ZMM(0), MEM(RDI,64))
|
||||
|
||||
JE(SCATTERBZ)
|
||||
|
||||
UPDATE_C_COL_SCATTERED( 4, 5)
|
||||
UPDATE_C_COL_SCATTERED( 6, 7)
|
||||
UPDATE_C_COL_SCATTERED( 8, 9)
|
||||
UPDATE_C_COL_SCATTERED(10,11)
|
||||
UPDATE_C_COL_SCATTERED(12,13)
|
||||
UPDATE_C_COL_SCATTERED(14,15)
|
||||
UPDATE_C_COL_SCATTERED(16,17)
|
||||
UPDATE_C_COL_SCATTERED(18,19)
|
||||
UPDATE_C_COL_SCATTERED(20,21)
|
||||
UPDATE_C_COL_SCATTERED(22,23)
|
||||
UPDATE_C_COL_SCATTERED(24,25)
|
||||
UPDATE_C_COL_SCATTERED(26,27)
|
||||
UPDATE_C_COL_SCATTERED(28,29)
|
||||
UPDATE_C_COL_SCATTERED(30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(SCATTERBZ)
|
||||
|
||||
UPDATE_C_BZ_COL_SCATTERED( 4, 5)
|
||||
UPDATE_C_BZ_COL_SCATTERED( 6, 7)
|
||||
UPDATE_C_BZ_COL_SCATTERED( 8, 9)
|
||||
UPDATE_C_BZ_COL_SCATTERED(10,11)
|
||||
UPDATE_C_BZ_COL_SCATTERED(12,13)
|
||||
UPDATE_C_BZ_COL_SCATTERED(14,15)
|
||||
UPDATE_C_BZ_COL_SCATTERED(16,17)
|
||||
UPDATE_C_BZ_COL_SCATTERED(18,19)
|
||||
UPDATE_C_BZ_COL_SCATTERED(20,21)
|
||||
UPDATE_C_BZ_COL_SCATTERED(22,23)
|
||||
UPDATE_C_BZ_COL_SCATTERED(24,25)
|
||||
UPDATE_C_BZ_COL_SCATTERED(26,27)
|
||||
UPDATE_C_BZ_COL_SCATTERED(28,29)
|
||||
UPDATE_C_BZ_COL_SCATTERED(30,31)
|
||||
UPDATE_C_BZ( 4, 5)
|
||||
UPDATE_C_BZ( 6, 7)
|
||||
UPDATE_C_BZ( 8, 9)
|
||||
UPDATE_C_BZ(10,11)
|
||||
UPDATE_C_BZ(12,13)
|
||||
UPDATE_C_BZ(14,15)
|
||||
UPDATE_C_BZ(16,17)
|
||||
UPDATE_C_BZ(18,19)
|
||||
UPDATE_C_BZ(20,21)
|
||||
UPDATE_C_BZ(22,23)
|
||||
UPDATE_C_BZ(24,25)
|
||||
UPDATE_C_BZ(26,27)
|
||||
UPDATE_C_BZ(28,29)
|
||||
UPDATE_C_BZ(30,31)
|
||||
|
||||
LABEL(END)
|
||||
|
||||
@@ -449,8 +377,7 @@ void bli_dgemm_skx_asm_16x14(
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
[rs_c] "m" (rs_c),
|
||||
[cs_c] "m" (cs_c),
|
||||
[offsetPtr] "m" (offsetPtr)
|
||||
[cs_c] "m" (cs_c)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
||||
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
||||
@@ -459,4 +386,6 @@ void bli_dgemm_skx_asm_16x14(
|
||||
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
||||
"zmm30", "zmm31", "memory"
|
||||
)
|
||||
|
||||
GEMM_UKR_FLUSH_CT( d );
|
||||
}
|
||||
|
||||
@@ -317,24 +317,28 @@ ahead*/
|
||||
static int64_t offsets[16] __attribute__((aligned(64))) =
|
||||
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
||||
|
||||
void bli_sgemm_skx_asm_32x12_l2(
|
||||
dim_t k_,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
float* restrict beta,
|
||||
float* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||
auxinfo_t* data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
void bli_sgemm_skx_asm_32x12_l2
|
||||
(
|
||||
dim_t m,
|
||||
dim_t n,
|
||||
dim_t k_,
|
||||
float* restrict alpha,
|
||||
float* restrict a,
|
||||
float* restrict b,
|
||||
float* restrict beta,
|
||||
float* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||
auxinfo_t* data,
|
||||
cntx_t* restrict cntx
|
||||
)
|
||||
{
|
||||
(void)data;
|
||||
(void)cntx;
|
||||
|
||||
const int64_t* offsetPtr = &offsets[0];
|
||||
const int64_t k = k_;
|
||||
const int64_t rs_c = rs_c_;
|
||||
const int64_t cs_c = cs_c_;
|
||||
int64_t k = k_;
|
||||
int64_t rs_c = rs_c_;
|
||||
int64_t cs_c = cs_c_;
|
||||
|
||||
GEMM_UKR_SETUP_CT( s, 32, 12, false );
|
||||
|
||||
BEGIN_ASM()
|
||||
|
||||
@@ -381,7 +385,7 @@ void bli_sgemm_skx_asm_32x12_l2(
|
||||
#endif
|
||||
|
||||
#ifdef PREFETCH_B_BEFORE
|
||||
/* Prefetching 3 cachlines of B (4 iterations worth of data
|
||||
/* Prefetching 3 cachlines of B (4 iterations worth of data
|
||||
(12 (NR) x 4 (sizeof(float)) x 4 iter /64 = 3 cachelines) */
|
||||
PREFETCH(0, MEM(RBX,0*64))
|
||||
PREFETCH(0, MEM(RBX,1*64))
|
||||
@@ -485,66 +489,26 @@ void bli_sgemm_skx_asm_32x12_l2(
|
||||
|
||||
MOV(RAX, VAR(cs_c))
|
||||
LEA(RAX, MEM(,RAX,4))
|
||||
MOV(RBX, VAR(rs_c))
|
||||
LEA(RBX, MEM(,RBX,4))
|
||||
|
||||
VCOMISS(XMM(1), XMM(7))
|
||||
JE(COLSTORBZ)
|
||||
|
||||
// Check if C is column major (rs_c = 1). If not, jump to the slow scattered update
|
||||
CMP(RBX, IMM(4))
|
||||
JNE(SCATTEREDUPDATE)
|
||||
|
||||
VCOMISS(XMM(1), XMM(7))
|
||||
JE(COLSTORBZ)
|
||||
|
||||
UPDATE_C( 8, 9,10,11)
|
||||
UPDATE_C(12,13,14,15)
|
||||
UPDATE_C(16,17,18,19)
|
||||
UPDATE_C(20,21,22,23)
|
||||
UPDATE_C(24,25,26,27)
|
||||
UPDATE_C(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(COLSTORBZ)
|
||||
|
||||
UPDATE_C_BZ( 8, 9,10,11)
|
||||
UPDATE_C_BZ(12,13,14,15)
|
||||
UPDATE_C_BZ(16,17,18,19)
|
||||
UPDATE_C_BZ(20,21,22,23)
|
||||
UPDATE_C_BZ(24,25,26,27)
|
||||
UPDATE_C_BZ(28,29,30,31)
|
||||
UPDATE_C( 8, 9,10,11)
|
||||
UPDATE_C(12,13,14,15)
|
||||
UPDATE_C(16,17,18,19)
|
||||
UPDATE_C(20,21,22,23)
|
||||
UPDATE_C(24,25,26,27)
|
||||
UPDATE_C(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(SCATTEREDUPDATE)
|
||||
LABEL(COLSTORBZ)
|
||||
|
||||
LEA(RDX, MEM(RCX,RBX,8))
|
||||
LEA(RDX, MEM(RDX,RBX,8))
|
||||
|
||||
MOV(RDI, VAR(offsetPtr))
|
||||
VMOVDQA64(ZMM(2), MEM(RDI,0*64))
|
||||
VMOVDQA64(ZMM(3), MEM(RDI,1*64))
|
||||
VPBROADCASTQ(ZMM(6), RBX)
|
||||
VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
|
||||
VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
|
||||
|
||||
VCOMISS(XMM(1), XMM(7))
|
||||
JE(SCATTERBZ)
|
||||
|
||||
UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
|
||||
UPDATE_C_ROW_SCATTERED(12,13,14,15)
|
||||
UPDATE_C_ROW_SCATTERED(16,17,18,19)
|
||||
UPDATE_C_ROW_SCATTERED(20,21,22,23)
|
||||
UPDATE_C_ROW_SCATTERED(24,25,26,27)
|
||||
UPDATE_C_ROW_SCATTERED(28,29,30,31)
|
||||
|
||||
JMP(END)
|
||||
LABEL(SCATTERBZ)
|
||||
|
||||
UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
|
||||
UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)
|
||||
UPDATE_C_BZ( 8, 9,10,11)
|
||||
UPDATE_C_BZ(12,13,14,15)
|
||||
UPDATE_C_BZ(16,17,18,19)
|
||||
UPDATE_C_BZ(20,21,22,23)
|
||||
UPDATE_C_BZ(24,25,26,27)
|
||||
UPDATE_C_BZ(28,29,30,31)
|
||||
|
||||
LABEL(END)
|
||||
|
||||
@@ -560,8 +524,7 @@ void bli_sgemm_skx_asm_32x12_l2(
|
||||
[beta] "m" (beta),
|
||||
[c] "m" (c),
|
||||
[rs_c] "m" (rs_c),
|
||||
[cs_c] "m" (cs_c),
|
||||
[offsetPtr] "m" (offsetPtr)
|
||||
[cs_c] "m" (cs_c)
|
||||
: // register clobber list
|
||||
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
||||
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
||||
@@ -570,4 +533,6 @@ void bli_sgemm_skx_asm_32x12_l2(
|
||||
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
||||
"zmm30", "zmm31", "memory"
|
||||
)
|
||||
|
||||
GEMM_UKR_FLUSH_CT( s );
|
||||
}
|
||||
|
||||
@@ -42,6 +42,8 @@
|
||||
\
|
||||
void PASTEMAC3(ch,opname,arch,suf) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
@@ -59,9 +61,6 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
\
|
||||
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
const dim_t m = mr; \
|
||||
const dim_t n = nr; \
|
||||
\
|
||||
const inc_t cs_a = packmr; \
|
||||
\
|
||||
|
||||
@@ -87,6 +87,8 @@ PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11", mr, 2*nr, \
|
||||
/* upper: b11 = alpha * b11 - a12 * b21; */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
mr, \
|
||||
nr, \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1x, \
|
||||
|
||||
@@ -44,6 +44,8 @@
|
||||
\
|
||||
void PASTEMAC3(ch,opname,arch,suf) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
@@ -107,8 +109,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
\
|
||||
if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||
{ \
|
||||
for ( dim_t i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < nr; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
PASTEMAC(ch,copys) \
|
||||
( \
|
||||
ab[ i*rs_ab + j*cs_ab ], \
|
||||
@@ -117,8 +119,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( dim_t i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < nr; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
PASTEMAC(ch,xpbys) \
|
||||
( \
|
||||
ab[ i*rs_ab + j*cs_ab ], \
|
||||
@@ -133,8 +135,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
\
|
||||
if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||
{ \
|
||||
for ( dim_t j = 0; j < nr; ++j ) \
|
||||
for ( dim_t i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
PASTEMAC(ch,copys) \
|
||||
( \
|
||||
ab[ i*rs_ab + j*cs_ab ], \
|
||||
@@ -143,8 +145,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( dim_t j = 0; j < nr; ++j ) \
|
||||
for ( dim_t i = 0; i < mr; ++i ) \
|
||||
for ( dim_t j = 0; j < n; ++j ) \
|
||||
for ( dim_t i = 0; i < m; ++i ) \
|
||||
PASTEMAC(ch,xpbys) \
|
||||
( \
|
||||
ab[ i*rs_ab + j*cs_ab ], \
|
||||
@@ -171,6 +173,8 @@ GENTFUNC( dcomplex, z, gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 )
|
||||
\
|
||||
void PASTEMAC3(ch,opname,arch,suf) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
@@ -188,9 +192,6 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
\
|
||||
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
const dim_t m = mr; \
|
||||
const dim_t n = nr; \
|
||||
\
|
||||
const inc_t cs_a = packmr; \
|
||||
\
|
||||
|
||||
@@ -52,6 +52,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
{ \
|
||||
const num_t dt = PASTEMAC(ch,type); \
|
||||
\
|
||||
const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
const inc_t rs_b = packnr; \
|
||||
@@ -68,6 +70,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
/* upper: b11 = alpha * b11 - a12 * b21; */ \
|
||||
gemm_ukr \
|
||||
( \
|
||||
mr, \
|
||||
nr, \
|
||||
k, \
|
||||
minus_one, \
|
||||
a1x, \
|
||||
|
||||
@@ -39,6 +39,8 @@
|
||||
\
|
||||
void PASTEMAC3(ch,opname,arch,suf) \
|
||||
( \
|
||||
dim_t m, \
|
||||
dim_t n, \
|
||||
dim_t k, \
|
||||
ctype* restrict alpha, \
|
||||
ctype* restrict a, \
|
||||
@@ -59,6 +61,9 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
\
|
||||
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||
\
|
||||
const dim_t mr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \
|
||||
const dim_t nr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \
|
||||
\
|
||||
const dim_t k2 = 2 * k; \
|
||||
\
|
||||
@@ -118,6 +123,11 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \
|
||||
else using_ct = FALSE; \
|
||||
\
|
||||
\
|
||||
/* If we are not computing a full micro-tile, then we must write to
|
||||
ct and then accumulate to c afterwards. */ \
|
||||
if ( mr != m || nr != n ) using_ct = TRUE; \
|
||||
\
|
||||
\
|
||||
if ( using_ct ) \
|
||||
{ \
|
||||
@@ -149,6 +159,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
/* c = beta * c + alpha_r * a * b; */ \
|
||||
rgemm_ukr \
|
||||
( \
|
||||
mr_r, \
|
||||
nr_r, \
|
||||
k2, \
|
||||
alpha_r, \
|
||||
a_r, \
|
||||
@@ -164,8 +176,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
/* Accumulate the final result in ct back to c. */ \
|
||||
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
||||
{ \
|
||||
for ( j = 0; j < nr; ++j ) \
|
||||
for ( i = 0; i < mr; ++i ) \
|
||||
for ( j = 0; j < n; ++j ) \
|
||||
for ( i = 0; i < m; ++i ) \
|
||||
{ \
|
||||
PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
|
||||
*(c + i*rs_c + j*cs_c ) ); \
|
||||
@@ -173,8 +185,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
} \
|
||||
else if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||
{ \
|
||||
for ( j = 0; j < nr; ++j ) \
|
||||
for ( i = 0; i < mr; ++i ) \
|
||||
for ( j = 0; j < n; ++j ) \
|
||||
for ( i = 0; i < m; ++i ) \
|
||||
{ \
|
||||
PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||
*(c + i*rs_c + j*cs_c ) ); \
|
||||
@@ -182,8 +194,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( j = 0; j < nr; ++j ) \
|
||||
for ( i = 0; i < mr; ++i ) \
|
||||
for ( j = 0; j < n; ++j ) \
|
||||
for ( i = 0; i < m; ++i ) \
|
||||
{ \
|
||||
PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||
*beta, \
|
||||
@@ -215,6 +227,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
/* c = beta * c + alpha_r * a * b; */ \
|
||||
rgemm_ukr \
|
||||
( \
|
||||
mr_r, \
|
||||
nr_r, \
|
||||
k2, \
|
||||
alpha_r, \
|
||||
a_r, \
|
||||
|
||||
@@ -153,6 +153,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
||||
upper: bt = -1.0 * a12 * b21; */ \
|
||||
rgemm_ukr \
|
||||
( \
|
||||
mr_r, \
|
||||
nr_r, \
|
||||
k2, \
|
||||
minus_one_r, \
|
||||
a1x_r, \
|
||||
|
||||
267
test/syrk_diagonal/complex_math.hpp
Normal file
267
test/syrk_diagonal/complex_math.hpp
Normal file
@@ -0,0 +1,267 @@
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
template <typename T>
|
||||
struct is_complex : std::false_type {};
|
||||
|
||||
template <>
|
||||
struct is_complex<scomplex> : std::true_type {};
|
||||
|
||||
template <>
|
||||
struct is_complex<dcomplex> : std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};
|
||||
|
||||
template <typename T> struct make_complex;
|
||||
|
||||
template <> struct make_complex<float > { using type = scomplex; };
|
||||
template <> struct make_complex<double > { using type = dcomplex; };
|
||||
template <> struct make_complex<scomplex> { using type = scomplex; };
|
||||
template <> struct make_complex<dcomplex> { using type = dcomplex; };
|
||||
|
||||
template <typename T>
|
||||
using make_complex_t = typename make_complex<T>::type;
|
||||
|
||||
template <typename T> struct make_real;
|
||||
|
||||
template <> struct make_real<float > { using type = float; };
|
||||
template <> struct make_real<double > { using type = double; };
|
||||
template <> struct make_real<scomplex> { using type = float; };
|
||||
template <> struct make_real<dcomplex> { using type = double; };
|
||||
|
||||
template <typename T>
|
||||
using make_real_t = typename make_real<T>::type;
|
||||
|
||||
template <typename T, bool Cond>
|
||||
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};
|
||||
|
||||
template <typename T, bool Cond>
|
||||
using make_complex_if_t = typename make_complex_if<T,Cond>::type;
|
||||
|
||||
template <typename T>
|
||||
struct real_imag_part
|
||||
{
|
||||
real_imag_part& operator=(T) { return *this; }
|
||||
|
||||
operator T() const { return T(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }
|
||||
|
||||
inline float& real(scomplex& x) { return x.real; }
|
||||
|
||||
inline float& imag(scomplex& x) { return x.imag; }
|
||||
|
||||
inline double& real(dcomplex& x) { return x.real; }
|
||||
|
||||
inline double& imag(dcomplex& x) { return x.imag; }
|
||||
|
||||
inline const float& real(const scomplex& x) { return x.real; }
|
||||
|
||||
inline const float& imag(const scomplex& x) { return x.imag; }
|
||||
|
||||
inline const double& real(const dcomplex& x) { return x.real; }
|
||||
|
||||
inline const double& imag(const dcomplex& x) { return x.imag; }
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
|
||||
|
||||
template <typename T, typename U, typename=void>
|
||||
struct convert_impl;
|
||||
|
||||
template <typename T, typename U>
|
||||
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
|
||||
{
|
||||
void operator()(T x, U& y) const { y = x; }
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
|
||||
{
|
||||
void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
|
||||
{
|
||||
void operator()(T x, U& y) const { y = x.real; }
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
|
||||
{
|
||||
void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
|
||||
};
|
||||
|
||||
template <typename U, typename T>
|
||||
U convert(T x)
|
||||
{
|
||||
U y;
|
||||
convert_impl<T,U>{}(x,y);
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename U, typename T>
|
||||
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
|
||||
{
|
||||
return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
|
||||
}
|
||||
|
||||
#define COMPLEX_MATH_OPS(rtype, ctype) \
|
||||
\
|
||||
inline bool operator==(rtype x, ctype y) \
|
||||
{ \
|
||||
return x == y.real && y.imag == 0; \
|
||||
} \
|
||||
\
|
||||
inline bool operator==(ctype x, rtype y) \
|
||||
{ \
|
||||
return y == x.real && x.imag == 0; \
|
||||
} \
|
||||
\
|
||||
inline bool operator==(ctype x, ctype y) \
|
||||
{ \
|
||||
return x.real == y.real && \
|
||||
x.imag == y.imag; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator-(ctype x) \
|
||||
{ \
|
||||
return {-x.real, -x.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator+(rtype x, ctype y) \
|
||||
{ \
|
||||
return {x+y.real, y.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator+(ctype x, rtype y) \
|
||||
{ \
|
||||
return {y+x.real, x.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator+(ctype x, ctype y) \
|
||||
{ \
|
||||
return {x.real+y.real, x.imag+y.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator-(rtype x, ctype y) \
|
||||
{ \
|
||||
return {x-y.real, -y.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator-(ctype x, rtype y) \
|
||||
{ \
|
||||
return {x.real-y, x.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator-(ctype x, ctype y) \
|
||||
{ \
|
||||
return {x.real-y.real, x.imag-y.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator*(rtype x, ctype y) \
|
||||
{ \
|
||||
return {x*y.real, x*y.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator*(ctype x, rtype y) \
|
||||
{ \
|
||||
return {y*x.real, y*x.imag}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator*(ctype x, ctype y) \
|
||||
{ \
|
||||
return {x.real*y.real - x.imag*y.imag, \
|
||||
x.real*y.imag + x.imag*y.real}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator/(rtype x, ctype y) \
|
||||
{ \
|
||||
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
|
||||
auto n = std::ilogb(scale); \
|
||||
auto yrs = std::scalbn(y.real, -n); \
|
||||
auto yis = std::scalbn(y.imag, -n); \
|
||||
auto denom = y.real*yrs + y.imag*yis; \
|
||||
return {x*yrs/denom, -x*yis/denom}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator/(ctype x, rtype y) \
|
||||
{ \
|
||||
return {x.real/y, x.imag/y}; \
|
||||
} \
|
||||
\
|
||||
inline ctype operator/(ctype x, ctype y) \
|
||||
{ \
|
||||
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
|
||||
auto n = std::ilogb(scale); \
|
||||
auto yrs = std::scalbn(y.real, -n); \
|
||||
auto yis = std::scalbn(y.imag, -n); \
|
||||
auto denom = y.real*yrs + y.imag*yis; \
|
||||
return {(x.real*yrs + x.imag*yis)/denom, \
|
||||
(x.imag*yrs - x.real*yis)/denom}; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator+=(ctype& x, rtype y) \
|
||||
{ \
|
||||
x.real += y; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator+=(ctype& x, ctype y) \
|
||||
{ \
|
||||
x.real += y.real; x.imag += y.imag; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator-=(ctype& x, rtype y) \
|
||||
{ \
|
||||
x.real -= y; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator-=(ctype& x, ctype y) \
|
||||
{ \
|
||||
x.real -= y.real; x.imag -= y.imag; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator*=(ctype& x, rtype y) \
|
||||
{ \
|
||||
x.real *= y; x.imag *= y; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator*=(ctype& x, ctype y) \
|
||||
{ \
|
||||
x = x * y; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator/=(ctype& x, rtype y) \
|
||||
{ \
|
||||
x.real /= y; x.imag /= y; \
|
||||
return x; \
|
||||
} \
|
||||
\
|
||||
inline ctype& operator/=(ctype& x, ctype y) \
|
||||
{ \
|
||||
x = x / y; \
|
||||
return x; \
|
||||
}
|
||||
|
||||
COMPLEX_MATH_OPS(float, scomplex);
|
||||
COMPLEX_MATH_OPS(double, dcomplex);
|
||||
|
||||
186
test/syrk_diagonal/syrk_diagonal_example.c
Normal file
186
test/syrk_diagonal/syrk_diagonal_example.c
Normal file
@@ -0,0 +1,186 @@
|
||||
#include "syrk_diagonal_ref.h"
|
||||
|
||||
/*
|
||||
* Structure which includes all additional information beyond what is
|
||||
* already stored in the obj_t structure.
|
||||
*
|
||||
* This structure is **read-only** during the operation!
|
||||
*/
|
||||
typedef struct packm_diag_params_t
|
||||
{
|
||||
packm_blk_var1_params_t super;
|
||||
void* d;
|
||||
inc_t incd;
|
||||
} packm_diag_params_t;
|
||||
|
||||
/*
|
||||
* Declare the pack kernel type and set up and array of
|
||||
* packing kernels, one for each data type.
|
||||
*/
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC(ctype,ch,op) \
|
||||
void PASTEMAC(ch,op) \
|
||||
( \
|
||||
struc_t struca, \
|
||||
diag_t diaga, \
|
||||
uplo_t uploa, \
|
||||
conj_t conja, \
|
||||
pack_t schema, \
|
||||
bool invdiag, \
|
||||
dim_t panel_dim, \
|
||||
dim_t panel_len, \
|
||||
dim_t panel_dim_max, \
|
||||
dim_t panel_len_max, \
|
||||
dim_t panel_dim_off, \
|
||||
dim_t panel_len_off, \
|
||||
void* restrict kappa, \
|
||||
void* restrict a, inc_t inca, inc_t lda, \
|
||||
void* restrict p, inc_t ldp, \
|
||||
inc_t is_p, \
|
||||
cntx_t* cntx, \
|
||||
void* params \
|
||||
) \
|
||||
{ \
|
||||
packm_diag_params_t* params_cast = params; \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict p_cast = p; \
|
||||
ctype* restrict d_cast = params_cast->d; \
|
||||
inc_t incd = params_cast->incd; \
|
||||
ctype kappa_cast = *( ctype* )kappa; \
|
||||
\
|
||||
if ( schema != BLIS_PACKED_ROW_PANELS && \
|
||||
schema != BLIS_PACKED_COL_PANELS ) \
|
||||
bli_abort(); \
|
||||
\
|
||||
/* Apply the offset */ \
|
||||
d_cast += panel_len_off * incd; \
|
||||
\
|
||||
if ( conja ) \
|
||||
{ \
|
||||
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||
{ \
|
||||
ctype kappa_d; \
|
||||
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||
\
|
||||
for (dim_t i = 0;i < panel_dim;i++) \
|
||||
PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||
\
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||
{ \
|
||||
ctype kappa_d; \
|
||||
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||
\
|
||||
for (dim_t i = 0;i < panel_dim;i++) \
|
||||
PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||
\
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
for (dim_t j = panel_len;j < panel_len_max;j++) \
|
||||
for (dim_t i = 0;i < panel_dim_max;i++) \
|
||||
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||
}
|
||||
|
||||
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
|
||||
|
||||
static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||
|
||||
/*
|
||||
* Modify the object A to include information about the diagonal D,
|
||||
* and imbue it with special function pointers which will take care
|
||||
* of the actual work of forming (D * A^T)
|
||||
*/
|
||||
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||
{
|
||||
memset( params, 0, sizeof(*params) );
|
||||
|
||||
// Assumes D is a column vector
|
||||
params->d = bli_obj_buffer_at_off( d );
|
||||
params->incd = bli_obj_row_stride( d );
|
||||
|
||||
for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
|
||||
params->super.ukr_fn[i][i] = packm_diag_ukrs[i];
|
||||
|
||||
// Attach the parameters to the A object.
|
||||
bli_obj_set_pack_params( params, a );
|
||||
}
|
||||
|
||||
/*
|
||||
* Implements C := alpha * A * D * A^T + beta * C
|
||||
*
|
||||
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||
*/
|
||||
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||
{
|
||||
obj_t ad; // this is (D * A^T)
|
||||
packm_diag_params_t params;
|
||||
|
||||
bli_obj_alias_to( a, &ad );
|
||||
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||
attach_diagonal_factor( ¶ms, d, &ad );
|
||||
|
||||
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||
bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
|
||||
}
|
||||
|
||||
int main( void )
|
||||
{
|
||||
obj_t a;
|
||||
obj_t d;
|
||||
obj_t c;
|
||||
obj_t c_copy;
|
||||
obj_t norm;
|
||||
|
||||
dim_t m = 10;
|
||||
dim_t k = 10;
|
||||
|
||||
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||
for ( int upper = 0; upper <= 1; upper++ )
|
||||
for ( int transa = 0; transa <= 1; transa++ )
|
||||
for ( int transc = 0; transc <= 1; transc++ )
|
||||
{
|
||||
num_t dt = dt_;
|
||||
uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||
|
||||
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||
bli_obj_set_uplo( uplo , &c );
|
||||
bli_obj_set_uplo( uplo , &c_copy );
|
||||
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||
|
||||
bli_randm( &a );
|
||||
bli_randm( &d );
|
||||
bli_randm( &c );
|
||||
bli_copym( &c, &c_copy );
|
||||
|
||||
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||
|
||||
bli_subm( &c_copy, &c );
|
||||
bli_normfm( &c, &norm );
|
||||
|
||||
double normr, normi;
|
||||
bli_getsc( &norm, &normr, &normi );
|
||||
|
||||
printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||
dt, upper, transa, transc, normr );
|
||||
|
||||
bli_obj_free( &a );
|
||||
bli_obj_free( &d );
|
||||
bli_obj_free( &c );
|
||||
bli_obj_free( &c_copy );
|
||||
bli_obj_free( &norm );
|
||||
}
|
||||
}
|
||||
220
test/syrk_diagonal/syrk_diagonal_example.cxx
Normal file
220
test/syrk_diagonal/syrk_diagonal_example.cxx
Normal file
@@ -0,0 +1,220 @@
|
||||
#include "syrk_diagonal_ref.h"
|
||||
|
||||
/*
|
||||
* Forward-declare the pack kernel type and set up and array of
|
||||
* packing kernels, one for each data type.
|
||||
*/
|
||||
template <typename T>
|
||||
void packm_diag_ukr
|
||||
(
|
||||
struc_t /*struca*/,
|
||||
diag_t /*diaga*/,
|
||||
uplo_t /*uploa*/,
|
||||
conj_t conja,
|
||||
pack_t schema,
|
||||
bool /*invdiag*/,
|
||||
dim_t panel_dim,
|
||||
dim_t panel_len,
|
||||
dim_t panel_dim_max,
|
||||
dim_t panel_len_max,
|
||||
dim_t /*panel_dim_off*/,
|
||||
dim_t panel_len_off,
|
||||
void* restrict kappa,
|
||||
void* restrict a, inc_t inca, inc_t lda,
|
||||
void* restrict p, inc_t ldp,
|
||||
inc_t /*is_p*/,
|
||||
cntx_t* /*cntx*/,
|
||||
void* params
|
||||
);
|
||||
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC(ctype,ch,op) \
|
||||
static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
|
||||
|
||||
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
|
||||
|
||||
static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||
|
||||
/*
|
||||
* Structure which includes all additional information beyond what is
|
||||
* already stored in the obj_t structure.
|
||||
*
|
||||
* This structure is **read-only** during the operation!
|
||||
*/
|
||||
struct packm_diag_params_t : packm_blk_var1_params_t
|
||||
{
|
||||
void* d;
|
||||
inc_t incd;
|
||||
|
||||
packm_diag_params_t() {}
|
||||
|
||||
packm_diag_params_t( void* d, inc_t incd )
|
||||
: d(d), incd(incd)
|
||||
{
|
||||
for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
|
||||
ukr_fn[i][i] = packm_diag_ukrs[i];
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* Selecting a different kernel based on the current architecture is
|
||||
* currently not possible, but is something we plan to support.
|
||||
*/
|
||||
template <typename T>
|
||||
void packm_diag_ukr
|
||||
(
|
||||
struc_t /*struca*/,
|
||||
diag_t /*diaga*/,
|
||||
uplo_t /*uploa*/,
|
||||
conj_t conja,
|
||||
pack_t schema,
|
||||
bool /*invdiag*/,
|
||||
dim_t panel_dim,
|
||||
dim_t panel_len,
|
||||
dim_t panel_dim_max,
|
||||
dim_t panel_len_max,
|
||||
dim_t /*panel_dim_off*/,
|
||||
dim_t panel_len_off,
|
||||
void* restrict kappa,
|
||||
void* restrict a, inc_t inca, inc_t lda,
|
||||
void* restrict p, inc_t ldp,
|
||||
inc_t /*is_p*/,
|
||||
cntx_t* /*cntx*/,
|
||||
void* params
|
||||
)
|
||||
{
|
||||
auto params_cast = ( packm_diag_params_t* )params;
|
||||
T* restrict a_cast = ( T* )a;
|
||||
T* restrict p_cast = ( T* )p;
|
||||
T* restrict d_cast = ( T* )params_cast->d;
|
||||
auto incd = params_cast->incd;
|
||||
auto kappa_cast = *( T* )kappa;
|
||||
|
||||
if ( schema != BLIS_PACKED_ROW_PANELS &&
|
||||
schema != BLIS_PACKED_COL_PANELS )
|
||||
bli_abort();
|
||||
|
||||
/* Apply the offset */
|
||||
d_cast += panel_len_off * incd;
|
||||
|
||||
if ( conja )
|
||||
{
|
||||
for ( dim_t j = 0; j < panel_len; j++ )
|
||||
{
|
||||
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||
|
||||
for (dim_t i = 0;i < panel_dim;i++)
|
||||
p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
|
||||
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for ( dim_t j = 0; j < panel_len; j++ )
|
||||
{
|
||||
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||
|
||||
for (dim_t i = 0;i < panel_dim;i++)
|
||||
p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
|
||||
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
for (dim_t j = panel_len;j < panel_len_max;j++)
|
||||
for (dim_t i = 0;i < panel_dim_max;i++)
|
||||
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Modify the object A to include information about the diagonal D,
|
||||
* and imbue it with special function pointers which will take care
|
||||
* of the actual work of forming (D * A^T)
|
||||
*/
|
||||
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||
{
|
||||
// Assumes D is a column vector
|
||||
new (params) packm_diag_params_t
|
||||
(
|
||||
bli_obj_buffer_at_off( d ),
|
||||
bli_obj_row_stride( d )
|
||||
);
|
||||
|
||||
// Attach the parameters to the A object.
|
||||
bli_obj_set_pack_params( params, a );
|
||||
}
|
||||
|
||||
/*
|
||||
* Implements C := alpha * A * D * A^T + beta * C
|
||||
*
|
||||
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||
*/
|
||||
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||
{
|
||||
obj_t ad; // this is (D * A^T)
|
||||
packm_diag_params_t params;
|
||||
|
||||
bli_obj_alias_to( a, &ad );
|
||||
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||
attach_diagonal_factor( ¶ms, d, &ad );
|
||||
|
||||
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||
bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
obj_t a;
|
||||
obj_t d;
|
||||
obj_t c;
|
||||
obj_t c_copy;
|
||||
obj_t norm;
|
||||
|
||||
auto m = 10;
|
||||
auto k = 10;
|
||||
|
||||
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||
for ( int upper = 0; upper <= 1; upper++ )
|
||||
for ( int transa = 0; transa <= 1; transa++ )
|
||||
for ( int transc = 0; transc <= 1; transc++ )
|
||||
{
|
||||
auto dt = ( num_t )dt_;
|
||||
auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||
|
||||
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||
bli_obj_set_uplo( uplo , &c );
|
||||
bli_obj_set_uplo( uplo , &c_copy );
|
||||
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||
|
||||
bli_randm( &a );
|
||||
bli_randm( &d );
|
||||
bli_randm( &c );
|
||||
bli_copym( &c, &c_copy );
|
||||
|
||||
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||
|
||||
bli_subm( &c_copy, &c );
|
||||
bli_normfm( &c, &norm );
|
||||
|
||||
double normr, normi;
|
||||
bli_getsc( &norm, &normr, &normi );
|
||||
|
||||
printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||
dt, upper, transa, transc, normr);
|
||||
|
||||
bli_obj_free( &a );
|
||||
bli_obj_free( &d );
|
||||
bli_obj_free( &c );
|
||||
bli_obj_free( &c_copy );
|
||||
bli_obj_free( &norm );
|
||||
}
|
||||
}
|
||||
354
test/syrk_diagonal/syrk_diagonal_example2.c
Normal file
354
test/syrk_diagonal/syrk_diagonal_example2.c
Normal file
@@ -0,0 +1,354 @@
|
||||
#include "syrk_diagonal_ref.h"
|
||||
|
||||
/*
|
||||
* Structure which includes all additional information beyond what is
|
||||
* already stored in the obj_t structure.
|
||||
*
|
||||
* This structure is **read-only** during the operation!
|
||||
*/
|
||||
typedef struct packm_diag_params_t
|
||||
{
|
||||
void* d;
|
||||
inc_t incd;
|
||||
} packm_diag_params_t;
|
||||
|
||||
typedef void (*packm_diag_ukr_vft)
|
||||
(
|
||||
bool conja,
|
||||
dim_t panel_dim,
|
||||
dim_t panel_len,
|
||||
dim_t panel_dim_max,
|
||||
dim_t panel_len_max,
|
||||
void* restrict kappa,
|
||||
void* restrict d, inc_t incd,
|
||||
void* restrict a, inc_t inca, inc_t lda,
|
||||
void* restrict p, inc_t ldp
|
||||
);
|
||||
|
||||
/*
|
||||
* Declare the pack kernel type and set up and array of
|
||||
* packing kernels, one for each data type.
|
||||
*/
|
||||
#undef GENTFUNC
|
||||
#define GENTFUNC(ctype,ch,op) \
|
||||
void PASTEMAC(ch,op) \
|
||||
( \
|
||||
bool conja, \
|
||||
dim_t panel_dim, \
|
||||
dim_t panel_len, \
|
||||
dim_t panel_dim_max, \
|
||||
dim_t panel_len_max, \
|
||||
void* restrict kappa, \
|
||||
void* restrict d, inc_t incd, \
|
||||
void* restrict a, inc_t inca, inc_t lda, \
|
||||
void* restrict p, inc_t ldp \
|
||||
) \
|
||||
{ \
|
||||
ctype* restrict a_cast = a; \
|
||||
ctype* restrict p_cast = p; \
|
||||
ctype* restrict d_cast = d; \
|
||||
ctype kappa_cast = *( ctype* )kappa; \
|
||||
\
|
||||
if ( conja ) \
|
||||
{ \
|
||||
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||
{ \
|
||||
ctype kappa_d; \
|
||||
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||
\
|
||||
for (dim_t i = 0;i < panel_dim;i++) \
|
||||
PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||
\
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||
{ \
|
||||
ctype kappa_d; \
|
||||
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||
\
|
||||
for (dim_t i = 0;i < panel_dim;i++) \
|
||||
PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||
\
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
for (dim_t j = panel_len;j < panel_len_max;j++) \
|
||||
for (dim_t i = 0;i < panel_dim_max;i++) \
|
||||
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||
}
|
||||
|
||||
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
|
||||
|
||||
static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||
|
||||
void packm_diag
|
||||
(
|
||||
obj_t* a,
|
||||
obj_t* p,
|
||||
cntx_t* cntx,
|
||||
rntm_t* rntm,
|
||||
cntl_t* cntl,
|
||||
thrinfo_t* thread
|
||||
)
|
||||
{
|
||||
#if 1
|
||||
|
||||
// We begin by copying the fields of A.
|
||||
bli_obj_alias_to( a, p );
|
||||
|
||||
// Get information about data types.
|
||||
num_t dt = bli_obj_dt( a );
|
||||
num_t dt_tar = bli_obj_target_dt( a );
|
||||
num_t dt_scalar = bli_obj_scalar_dt( a );
|
||||
dim_t dt_size = bli_dt_size( dt );
|
||||
|
||||
if ( dt_scalar != dt || dt_tar != dt )
|
||||
bli_abort();
|
||||
|
||||
// Extract various fields from the control tree.
|
||||
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
|
||||
bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
|
||||
pack_t schema = bli_cntl_packm_params_pack_schema( cntl );
|
||||
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
|
||||
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
|
||||
dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
|
||||
|
||||
if ( schema != BLIS_PACKED_ROW_PANELS &&
|
||||
schema != BLIS_PACKED_COL_PANELS )
|
||||
bli_abort();
|
||||
|
||||
// Store the pack schema to the object.
|
||||
bli_obj_set_pack_schema( schema, p );
|
||||
|
||||
// Clear the conjugation field from the object since matrix packing
|
||||
// in BLIS is deemed to take care of all conjugation necessary.
|
||||
bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
|
||||
|
||||
// If we are packing micropanels, mark P as dense.
|
||||
bli_obj_set_uplo( BLIS_DENSE, p );
|
||||
|
||||
// Reset the view offsets to (0,0).
|
||||
bli_obj_set_offs( 0, 0, p );
|
||||
|
||||
// Compute the dimensions padded by the dimension multiples. These
|
||||
// dimensions will be the dimensions of the packed matrices, including
|
||||
// zero-padding, and will be used by the macro- and micro-kernels.
|
||||
// We compute them by starting with the effective dimensions of A (now
|
||||
// in P) and aligning them to the dimension multiples (typically equal
|
||||
// to register blocksizes). This does waste a little bit of space for
|
||||
// level-2 operations, but that's okay with us.
|
||||
dim_t m_p = bli_obj_length( p );
|
||||
dim_t n_p = bli_obj_width( p );
|
||||
dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
|
||||
dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
|
||||
|
||||
// Save the padded dimensions into the packed object. It is important
|
||||
// to save these dimensions since they represent the actual dimensions
|
||||
// of the zero-padded matrix.
|
||||
bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
|
||||
|
||||
// The "panel stride" of a micropanel packed object is interpreted as
|
||||
// the distance between the (0,0) element of panel k and the (0,0)
|
||||
// element of panel k+1. We use the padded width computed above to
|
||||
// allow for zero-padding (if necessary/desired) along the far end
|
||||
// of each micropanel (ie: the right edge of the matrix). Zero-padding
|
||||
// can also occur along the long edge of the last micropanel if the m
|
||||
// dimension of the matrix is not a whole multiple of MR.
|
||||
inc_t ps_p = bmult_m_pack * n_p_pad;
|
||||
|
||||
/* Compute the total number of iterations we'll need. */
|
||||
dim_t n_iter = m_p_pad / bmult_m_def;
|
||||
|
||||
// Store the strides and panel dimension in P.
|
||||
bli_obj_set_strides( 1, bmult_m_pack, p );
|
||||
bli_obj_set_imag_stride( 1, p );
|
||||
bli_obj_set_panel_dim( bmult_m_def, p );
|
||||
bli_obj_set_panel_stride( ps_p, p );
|
||||
bli_obj_set_panel_length( bmult_m_def, p );
|
||||
bli_obj_set_panel_width( n_p, p );
|
||||
|
||||
// Compute the size of the packed buffer.
|
||||
siz_t size_p = ps_p * n_iter * dt_size;
|
||||
if ( size_p == 0 ) return;
|
||||
|
||||
// Update the buffer address in p to point to the buffer associated
|
||||
// with the mem_t entry acquired from the memory broker (now cached in
|
||||
// the control tree node).
|
||||
char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
|
||||
bli_obj_set_buffer( p_cast, p );
|
||||
|
||||
#else
|
||||
|
||||
// Every thread initializes p and determines the size of memory
|
||||
// block needed (which gets embedded into the otherwise "blank" mem_t
|
||||
// entry in the control tree node). Return early if no packing is required.
|
||||
if ( !bli_packm_init( a, p, cntx, rntm, cntl, thread ) )
|
||||
return;
|
||||
|
||||
num_t dt = bli_obj_dt( a );
|
||||
dim_t dt_size = bli_dt_size( dt );
|
||||
|
||||
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
|
||||
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt, bmult_id_m, cntx );
|
||||
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt, bmult_id_m, cntx );
|
||||
|
||||
dim_t m_p = bli_obj_length( p );
|
||||
dim_t n_p = bli_obj_width( p );
|
||||
dim_t m_p_pad = bli_obj_padded_length( p );
|
||||
dim_t n_p_pad = bli_obj_padded_width( p );
|
||||
dim_t n_iter = m_p_pad / bmult_m_def;
|
||||
|
||||
char* p_cast = bli_obj_buffer( p );
|
||||
inc_t ps_p = bli_obj_panel_stride( p );
|
||||
|
||||
#endif
|
||||
|
||||
char* a_cast = bli_obj_buffer_at_off( a );
|
||||
inc_t inca = bli_obj_row_stride( a );
|
||||
inc_t lda = bli_obj_col_stride( a );
|
||||
dim_t panel_len_off = bli_obj_col_off( a );
|
||||
conj_t conja = bli_obj_conj_status( a );
|
||||
|
||||
packm_diag_params_t* params = bli_obj_pack_params( a );
|
||||
char* d_cast = params->d;
|
||||
inc_t incd = params->incd;
|
||||
|
||||
obj_t kappa_local;
|
||||
char* kappa_cast = bli_packm_scalar( &kappa_local, p );
|
||||
|
||||
packm_diag_ukr_vft packm_ker_cast = packm_diag_ukrs[ dt ];
|
||||
|
||||
/* Query the number of threads and thread ids from the current thread's
|
||||
packm thrinfo_t node. */
|
||||
const dim_t nt = bli_thread_n_way( thread );
|
||||
const dim_t tid = bli_thread_work_id( thread );
|
||||
|
||||
/* Determine the thread range and increment using the current thread's
|
||||
packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
|
||||
will depend on whether slab or round-robin partitioning was requested
|
||||
at configure-time. */
|
||||
dim_t it_start, it_end, it_inc;
|
||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
|
||||
|
||||
/* Iterate over every logical micropanel in the source matrix. */
|
||||
for ( dim_t it = 0; it < n_iter; it += 1 )
|
||||
{
|
||||
dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
|
||||
|
||||
char* d_begin = d_cast + panel_len_off*incd*dt_size;
|
||||
char* a_begin = a_cast + it* bmult_m_def*inca*dt_size;
|
||||
char* p_begin = p_cast + it* ps_p*dt_size;
|
||||
|
||||
if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
|
||||
{
|
||||
packm_ker_cast
|
||||
(
|
||||
conja,
|
||||
panel_dim_i,
|
||||
n_p,
|
||||
bmult_m_def,
|
||||
n_p_pad,
|
||||
kappa_cast,
|
||||
d_begin, incd,
|
||||
a_begin, inca, lda,
|
||||
p_begin, bmult_m_pack
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Modify the object A to include information about the diagonal D,
|
||||
* and imbue it with special function pointers which will take care
|
||||
* of the actual work of forming (D * A^T)
|
||||
*/
|
||||
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||
{
|
||||
// Assumes D is a column vector
|
||||
params->d = bli_obj_buffer_at_off( d );
|
||||
params->incd = bli_obj_row_stride( d );
|
||||
|
||||
// Set the custom pack function.
|
||||
bli_obj_set_pack_fn( packm_diag, a );
|
||||
|
||||
// Attach the parameters to the A object.
|
||||
bli_obj_set_pack_params( params, a );
|
||||
}
|
||||
|
||||
/*
|
||||
* Implements C := alpha * A * D * A^T + beta * C
|
||||
*
|
||||
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||
*/
|
||||
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||
{
|
||||
obj_t ad; // this is (D * A^T)
|
||||
packm_diag_params_t params;
|
||||
|
||||
bli_obj_alias_to( a, &ad );
|
||||
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||
attach_diagonal_factor( ¶ms, d, &ad );
|
||||
|
||||
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||
bli_gemmt( alpha, a, &ad, beta, c );
|
||||
}
|
||||
|
||||
int main( void )
|
||||
{
|
||||
obj_t a;
|
||||
obj_t d;
|
||||
obj_t c;
|
||||
obj_t c_copy;
|
||||
obj_t norm;
|
||||
|
||||
dim_t m = 10;
|
||||
dim_t k = 10;
|
||||
|
||||
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||
for ( int upper = 0; upper <= 1; upper++ )
|
||||
for ( int transa = 0; transa <= 1; transa++ )
|
||||
for ( int transc = 0; transc <= 1; transc++ )
|
||||
{
|
||||
num_t dt = dt_;
|
||||
uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||
|
||||
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||
bli_obj_set_uplo( uplo , &c );
|
||||
bli_obj_set_uplo( uplo , &c_copy );
|
||||
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||
|
||||
bli_randm( &a );
|
||||
bli_randm( &d );
|
||||
bli_randm( &c );
|
||||
bli_copym( &c, &c_copy );
|
||||
|
||||
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||
|
||||
bli_subm( &c_copy, &c );
|
||||
bli_normfm( &c, &norm );
|
||||
|
||||
double normr, normi;
|
||||
bli_getsc( &norm, &normr, &normi );
|
||||
|
||||
printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||
dt, upper, transa, transc, normr );
|
||||
|
||||
bli_obj_free( &a );
|
||||
bli_obj_free( &d );
|
||||
bli_obj_free( &c );
|
||||
bli_obj_free( &c_copy );
|
||||
bli_obj_free( &norm );
|
||||
}
|
||||
}
|
||||
338
test/syrk_diagonal/syrk_diagonal_example2.cxx
Normal file
338
test/syrk_diagonal/syrk_diagonal_example2.cxx
Normal file
@@ -0,0 +1,338 @@
|
||||
#include "syrk_diagonal_ref.h"
|
||||
|
||||
/*
 * Forward-declare the pack kernel type and set up an array of
 * packing kernels, one for each data type.
 */
template <typename T>
void packm_diag_ukr
     (
       bool conja,
       dim_t panel_dim,      // rows actually present in this micropanel
       dim_t panel_len,      // columns actually present
       dim_t panel_dim_max,  // rows including zero padding
       dim_t panel_len_max,  // columns including zero padding
       void* restrict kappa,
       void* restrict d, inc_t incd,
       void* restrict a, inc_t inca, inc_t lda,
       void* restrict p, inc_t ldp
     );

// Instantiate the template for each BLIS datatype under the conventional
// PASTEMAC(ch,op) name.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;

INSERT_GENTFUNC_BASIC0(packm_diag_ukr);

// Uniform function-pointer type and the datatype-indexed kernel table.
using packm_diag_ukr_vft = decltype(&packm_diag_ukr<void>);
static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||
|
||||
/*
 * Structure which includes all additional information beyond what is
 * already stored in the obj_t structure.
 *
 * This structure is **read-only** during the operation!
 */
struct packm_diag_params_t
{
    void* d;    // buffer holding the elements of the diagonal matrix D
    inc_t incd; // stride between consecutive elements of d

    packm_diag_params_t() {}

    packm_diag_params_t( void* d, inc_t incd )
    : d(d), incd(incd) {}
};
|
||||
|
||||
/*
|
||||
* Selecting a different kernel based on the current architecture is
|
||||
* currently not possible, but is something we plan to support.
|
||||
*/
|
||||
template <typename T>
|
||||
void packm_diag_ukr
|
||||
(
|
||||
bool conja,
|
||||
dim_t panel_dim,
|
||||
dim_t panel_len,
|
||||
dim_t panel_dim_max,
|
||||
dim_t panel_len_max,
|
||||
void* restrict kappa,
|
||||
void* restrict d, inc_t incd,
|
||||
void* restrict a, inc_t inca, inc_t lda,
|
||||
void* restrict p, inc_t ldp
|
||||
)
|
||||
{
|
||||
T* restrict a_cast = ( T* )a;
|
||||
T* restrict p_cast = ( T* )p;
|
||||
T* restrict d_cast = ( T* )d;
|
||||
auto kappa_cast = *( T* )kappa;
|
||||
|
||||
if ( conja )
|
||||
{
|
||||
for ( dim_t j = 0; j < panel_len; j++ )
|
||||
{
|
||||
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||
|
||||
for (dim_t i = 0;i < panel_dim;i++)
|
||||
p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
|
||||
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for ( dim_t j = 0; j < panel_len; j++ )
|
||||
{
|
||||
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||
|
||||
for (dim_t i = 0;i < panel_dim;i++)
|
||||
p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
|
||||
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
for (dim_t j = panel_len;j < panel_len_max;j++)
|
||||
for (dim_t i = 0;i < panel_dim_max;i++)
|
||||
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||
}
|
||||
|
||||
/*
 * Custom pack-function replacement installed on A by
 * attach_diagonal_factor(). Initializes the packed object P from A,
 * allocates the pack buffer, and packs A (with the diagonal D folded in
 * by the packm_diag_ukrs kernels) into micropanel format, partitioned
 * across the threads of this packm thrinfo_t node.
 */
void packm_diag
     (
       obj_t* a,
       obj_t* p,
       cntx_t* cntx,
       rntm_t* rntm,
       cntl_t* cntl,
       thrinfo_t* thread
     )
{
    // We begin by copying the fields of A.
    bli_obj_alias_to( a, p );

    // Get information about data types.
    num_t dt = bli_obj_dt( a );
    num_t dt_tar = bli_obj_target_dt( a );
    num_t dt_scalar = bli_obj_scalar_dt( a );
    dim_t dt_size = bli_dt_size( dt );

    // Mixed-datatype execution is not supported by this example.
    if ( dt_scalar != dt || dt_tar != dt )
        bli_abort();

    // Extract various fields from the control tree.
    bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
    bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
    pack_t schema = bli_cntl_packm_params_pack_schema( cntl );
    dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
    dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
    dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );

    // Only plain row-/column-panel packing is handled here.
    if ( schema != BLIS_PACKED_ROW_PANELS &&
         schema != BLIS_PACKED_COL_PANELS )
        bli_abort();

    // Store the pack schema to the object.
    bli_obj_set_pack_schema( schema, p );

    // Clear the conjugation field from the object since matrix packing
    // in BLIS is deemed to take care of all conjugation necessary.
    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );

    // If we are packing micropanels, mark P as dense.
    bli_obj_set_uplo( BLIS_DENSE, p );

    // Reset the view offsets to (0,0).
    bli_obj_set_offs( 0, 0, p );

    // Compute the dimensions padded by the dimension multiples. These
    // dimensions will be the dimensions of the packed matrices, including
    // zero-padding, and will be used by the macro- and micro-kernels.
    // We compute them by starting with the effective dimensions of A (now
    // in P) and aligning them to the dimension multiples (typically equal
    // to register blocksizes). This does waste a little bit of space for
    // level-2 operations, but that's okay with us.
    dim_t m_p = bli_obj_length( p );
    dim_t n_p = bli_obj_width( p );
    dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
    dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );

    // Save the padded dimensions into the packed object. It is important
    // to save these dimensions since they represent the actual dimensions
    // of the zero-padded matrix.
    bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );

    // The "panel stride" of a micropanel packed object is interpreted as
    // the distance between the (0,0) element of panel k and the (0,0)
    // element of panel k+1. We use the padded width computed above to
    // allow for zero-padding (if necessary/desired) along the far end
    // of each micropanel (ie: the right edge of the matrix). Zero-padding
    // can also occur along the long edge of the last micropanel if the m
    // dimension of the matrix is not a whole multiple of MR.
    inc_t ps_p = bmult_m_pack * n_p_pad;

    /* Compute the total number of iterations we'll need. */
    dim_t n_iter = m_p_pad / bmult_m_def;

    // Store the strides and panel dimension in P.
    bli_obj_set_strides( 1, bmult_m_pack, p );
    bli_obj_set_imag_stride( 1, p );
    bli_obj_set_panel_dim( bmult_m_def, p );
    bli_obj_set_panel_stride( ps_p, p );
    bli_obj_set_panel_length( bmult_m_def, p );
    bli_obj_set_panel_width( n_p, p );

    // Compute the size of the packed buffer.
    siz_t size_p = ps_p * n_iter * dt_size;
    if ( size_p == 0 ) return;

    // Update the buffer address in p to point to the buffer associated
    // with the mem_t entry acquired from the memory broker (now cached in
    // the control tree node).
    char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
    bli_obj_set_buffer( p_cast, p );

    // Raw source buffer, strides, and conjugation status of A.
    char* a_cast = (char*)bli_obj_buffer_at_off( a );
    inc_t inca = bli_obj_row_stride( a );
    inc_t lda = bli_obj_col_stride( a );
    dim_t panel_len_off = bli_obj_col_off( a );
    conj_t conja = bli_obj_conj_status( a );

    // Recover the diagonal-factor parameters attached by
    // attach_diagonal_factor().
    auto params = (packm_diag_params_t*)bli_obj_pack_params( a );
    char* d_cast = (char*)params->d;
    inc_t incd = params->incd;

    obj_t kappa_local;
    char* kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );

    // Select the pack kernel for this datatype.
    auto packm_ker_cast = packm_diag_ukrs[ dt ];

    /* Query the number of threads and thread ids from the current thread's
       packm thrinfo_t node. */
    const dim_t nt = bli_thread_n_way( thread );
    const dim_t tid = bli_thread_work_id( thread );

    /* Determine the thread range and increment using the current thread's
       packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
       will depend on whether slab or round-robin partitioning was requested
       at configure-time. */
    dim_t it_start, it_end, it_inc;
    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );

    /* Iterate over every logical micropanel in the source matrix. */
    for ( dim_t it = 0; it < n_iter; it += 1 )
    {
        // Rows remaining in this (possibly partial, final) micropanel.
        dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );

        char* d_begin = d_cast + panel_len_off*incd*dt_size;
        char* a_begin = a_cast + it* bmult_m_def*inca*dt_size;
        char* p_begin = p_cast + it* ps_p*dt_size;

        // Each thread packs only the micropanels assigned to it.
        if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
        {
            // NOTE(review): conja (a conj_t) is narrowed to the kernel's
            // bool parameter; this assumes a nonzero conj_t value means
            // "conjugate" — confirm against BLIS_CONJUGATE's encoding.
            packm_ker_cast( conja,
                            panel_dim_i,
                            n_p,
                            bmult_m_def,
                            n_p_pad,
                            kappa_cast,
                            d_begin, incd,
                            a_begin, inca, lda,
                            p_begin, bmult_m_pack );
        }
    }
}
|
||||
|
||||
/*
|
||||
* Modify the object A to include information about the diagonal D,
|
||||
* and imbue it with special function pointers which will take care
|
||||
* of the actual work of forming (D * A^T)
|
||||
*/
|
||||
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||
{
|
||||
// Assumes D is a column vector
|
||||
new (params) packm_diag_params_t
|
||||
(
|
||||
bli_obj_buffer_at_off( d ),
|
||||
bli_obj_row_stride( d )
|
||||
);
|
||||
|
||||
// Set the custom pack function.
|
||||
bli_obj_set_pack_fn( packm_diag, a );
|
||||
|
||||
// Attach the parameters to the A object.
|
||||
bli_obj_set_pack_params( params, a );
|
||||
}
|
||||
|
||||
/*
 * Implements C := alpha * A * D * A^T + beta * C
 *
 * where D is a diagonal matrix with elements taken from the "d" vector.
 */
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
    obj_t ad; // this is (D * A^T)
    packm_diag_params_t params;

    // ad aliases A (transposed); D is folded in during packing by the
    // custom pack kernel attached below.
    bli_obj_alias_to( a, &ad );
    bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
    attach_diagonal_factor( &params, d, &ad );

    // Does C := alpha * A * B + beta * C using B = (D * A^T).
    // params must outlive this call: it is read while B is packed.
    bli_gemmt( alpha, a, &ad, beta, c );
}
|
||||
|
||||
int main()
|
||||
{
|
||||
obj_t a;
|
||||
obj_t d;
|
||||
obj_t c;
|
||||
obj_t c_copy;
|
||||
obj_t norm;
|
||||
|
||||
auto m = 10;
|
||||
auto k = 10;
|
||||
|
||||
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||
for ( int upper = 0; upper <= 1; upper++ )
|
||||
for ( int transa = 0; transa <= 1; transa++ )
|
||||
for ( int transc = 0; transc <= 1; transc++ )
|
||||
{
|
||||
auto dt = ( num_t )dt_;
|
||||
auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||
|
||||
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||
bli_obj_set_uplo( uplo , &c );
|
||||
bli_obj_set_uplo( uplo , &c_copy );
|
||||
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||
|
||||
bli_randm( &a );
|
||||
bli_randm( &d );
|
||||
bli_randm( &c );
|
||||
bli_copym( &c, &c_copy );
|
||||
|
||||
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||
|
||||
bli_subm( &c_copy, &c );
|
||||
bli_normfm( &c, &norm );
|
||||
|
||||
double normr, normi;
|
||||
bli_getsc( &norm, &normr, &normi );
|
||||
|
||||
printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||
dt, upper, transa, transc, normr);
|
||||
|
||||
bli_obj_free( &a );
|
||||
bli_obj_free( &d );
|
||||
bli_obj_free( &c );
|
||||
bli_obj_free( &c_copy );
|
||||
bli_obj_free( &norm );
|
||||
}
|
||||
}
|
||||
102
test/syrk_diagonal/syrk_diagonal_ref.cxx
Normal file
102
test/syrk_diagonal/syrk_diagonal_ref.cxx
Normal file
@@ -0,0 +1,102 @@
|
||||
#include "syrk_diagonal_ref.h"
|
||||
#include "complex_math.hpp"
|
||||
|
||||
// Signature of one type-specific reference kernel; used to build the
// datatype-indexed dispatch table below.
typedef void (*syrk_diag_ref_vft)
    (
      uplo_t uplo,
      dim_t m,
      dim_t k,
      void* alpha,
      void* a, inc_t rs_a, inc_t cs_a,
      void* d, inc_t incd,
      void* beta,
      void* c, inc_t rs_c, inc_t cs_c
    );
|
||||
|
||||
template <typename T>
|
||||
void syrk_diag_ref
|
||||
(
|
||||
uplo_t uplo,
|
||||
dim_t m,
|
||||
dim_t k,
|
||||
void* alpha,
|
||||
void* a, inc_t rs_a, inc_t cs_a,
|
||||
void* d, inc_t incd,
|
||||
void* beta,
|
||||
void* c, inc_t rs_c, inc_t cs_c
|
||||
)
|
||||
{
|
||||
auto alpha_cast = *( T* )alpha;
|
||||
auto beta_cast = *( T* )beta;
|
||||
auto a_cast = ( T* )a;
|
||||
auto d_cast = ( T* )d;
|
||||
auto c_cast = ( T* )c;
|
||||
|
||||
for ( dim_t i = 0; i < m; i++ )
|
||||
{
|
||||
dim_t j_min = uplo == BLIS_UPPER ? i : 0;
|
||||
dim_t j_max = uplo == BLIS_UPPER ? m : i+1;
|
||||
|
||||
for ( dim_t j = j_min; j < j_max; j++ )
|
||||
{
|
||||
auto ada = convert<T>(0.0);
|
||||
|
||||
for ( dim_t p = 0; p < k; p++ )
|
||||
{
|
||||
ada += a_cast[ i*rs_a + p*cs_a ] *
|
||||
d_cast[ p*incd ] *
|
||||
a_cast[ j*rs_a + p*cs_a ];
|
||||
}
|
||||
|
||||
if ( beta_cast == convert<T>(0.0) )
|
||||
{
|
||||
c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada;
|
||||
}
|
||||
else
|
||||
{
|
||||
c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada +
|
||||
beta_cast * c_cast[ i*rs_c + j*cs_c ];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate the reference kernel for each BLIS datatype and collect the
// instantiations into a datatype-indexed dispatch table.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &syrk_diag_ref<ctype>;

INSERT_GENTFUNC_BASIC0(syrk_diag_ref);

static syrk_diag_ref_vft GENARRAY( syrk_diag_ref_impl, syrk_diag_ref );
|
||||
|
||||
/*
 * Object-based wrapper: extracts raw buffers, dimensions, and strides
 * from the obj_t operands and dispatches to the type-specific reference
 * kernel. All operands are assumed to share one datatype.
 */
void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
    num_t dt = bli_obj_dt( a );

    dim_t m = bli_obj_length_after_trans( a );
    dim_t k = bli_obj_width_after_trans( a );

    inc_t rs_a = bli_obj_row_stride( a );
    inc_t cs_a = bli_obj_col_stride( a );
    inc_t rs_c = bli_obj_row_stride( c );
    inc_t cs_c = bli_obj_col_stride( c );
    inc_t incd = bli_obj_row_stride( d ); // d is a column vector

    // A logical transposition is effected by swapping the strides.
    if ( bli_obj_has_trans( a ) ) bli_swap_incs( &rs_a, &cs_a );
    if ( bli_obj_has_trans( c ) ) bli_swap_incs( &rs_c, &cs_c );

    syrk_diag_ref_impl[ dt ]( bli_obj_uplo( c ),
                              m, k,
                              bli_obj_buffer_for_1x1( dt, alpha ),
                              bli_obj_buffer_at_off( a ), rs_a, cs_a,
                              bli_obj_buffer_at_off( d ), incd,
                              bli_obj_buffer_for_1x1( dt, beta ),
                              bli_obj_buffer_at_off( c ), rs_c, cs_c );
}
|
||||
|
||||
8
test/syrk_diagonal/syrk_diagonal_ref.h
Normal file
8
test/syrk_diagonal/syrk_diagonal_ref.h
Normal file
@@ -0,0 +1,8 @@
|
||||
// Include guard: this header is included from both C and C++ translation
// units and may be included more than once.
#ifndef SYRK_DIAGONAL_REF_H
#define SYRK_DIAGONAL_REF_H

#include "blis.h"

#ifdef __cplusplus
#include "complex_math.hpp"
// The reference routine is implemented in C++ but must be callable from C.
extern "C"
#endif
void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c );

#endif // SYRK_DIAGONAL_REF_H
|
||||
|
||||
267
test/tensor_contraction/complex_math.hpp
Normal file
267
test/tensor_contraction/complex_math.hpp
Normal file
@@ -0,0 +1,267 @@
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
// Type predicate: true only for the BLIS complex types.
template <typename T>
struct is_complex : std::false_type {};

template <>
struct is_complex<scomplex> : std::true_type {};

template <>
struct is_complex<dcomplex> : std::true_type {};

// Complement of is_complex.
template <typename T>
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};

// Map a real or complex type to the complex type of the same precision.
template <typename T> struct make_complex;

template <> struct make_complex<float > { using type = scomplex; };
template <> struct make_complex<double > { using type = dcomplex; };
template <> struct make_complex<scomplex> { using type = scomplex; };
template <> struct make_complex<dcomplex> { using type = dcomplex; };

template <typename T>
using make_complex_t = typename make_complex<T>::type;

// Map a real or complex type to the real type of the same precision.
template <typename T> struct make_real;

template <> struct make_real<float > { using type = float; };
template <> struct make_real<double > { using type = double; };
template <> struct make_real<scomplex> { using type = float; };
template <> struct make_real<dcomplex> { using type = double; };

template <typename T>
using make_real_t = typename make_real<T>::type;

// Select complex (Cond) or real (!Cond) at the precision of T.
template <typename T, bool Cond>
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};

template <typename T, bool Cond>
using make_complex_if_t = typename make_complex_if<T,Cond>::type;

// Write-discarding proxy returned by imag() of a real value: assigning
// to it is a no-op and reading it yields zero.
template <typename T>
struct real_imag_part
{
    real_imag_part& operator=(T) { return *this; }

    operator T() const { return T(); }
};

// real()/imag() accessors for both real and complex values; for real
// types, imag() returns the discarding proxy above.
template <typename T>
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }

template <typename T>
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }

inline float& real(scomplex& x) { return x.real; }

inline float& imag(scomplex& x) { return x.imag; }

inline double& real(dcomplex& x) { return x.real; }

inline double& imag(dcomplex& x) { return x.imag; }

inline const float& real(const scomplex& x) { return x.real; }

inline const float& imag(const scomplex& x) { return x.imag; }

inline const double& real(const dcomplex& x) { return x.real; }

inline const double& imag(const dcomplex& x) { return x.imag; }

// Complex conjugate; the identity for real types.
template <typename T>
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }

template <typename T>
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }

// convert_impl: value conversion between any combination of real and
// complex types, dropping or zero-filling the imaginary part as needed.
template <typename T, typename U, typename=void>
struct convert_impl;

template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
{
    void operator()(T x, U& y) const { y = x; }
};

template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
{
    void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
};

template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
{
    void operator()(T x, U& y) const { y = x.real; }
};

template <typename T, typename U>
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
{
    void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
};

// convert<U>(x): the value of x expressed as type U.
template <typename U, typename T>
U convert(T x)
{
    U y;
    convert_impl<T,U>{}(x,y);
    return y;
}

// convert_prec<U>(x): convert x to the precision of U while preserving
// the real/complex-ness of x itself.
template <typename U, typename T>
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
{
    return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
}
|
||||
|
||||
// Defines the full set of arithmetic and comparison operators for a BLIS
// complex struct `ctype` paired with its real type `rtype`: ==, unary -,
// and +, -, *, / plus the compound assignments, for every real/complex
// operand combination. Division scales the divisor by a power of two
// (ilogb/scalbn) before forming the denominator to avoid intermediate
// overflow/underflow. (Comments cannot appear inside the macro body
// because // would swallow the line-continuation backslashes.)
#define COMPLEX_MATH_OPS(rtype, ctype) \
\
inline bool operator==(rtype x, ctype y) \
{ \
    return x == y.real && y.imag == 0; \
} \
\
inline bool operator==(ctype x, rtype y) \
{ \
    return y == x.real && x.imag == 0; \
} \
\
inline bool operator==(ctype x, ctype y) \
{ \
    return x.real == y.real && \
           x.imag == y.imag; \
} \
\
inline ctype operator-(ctype x) \
{ \
    return {-x.real, -x.imag}; \
} \
\
inline ctype operator+(rtype x, ctype y) \
{ \
    return {x+y.real, y.imag}; \
} \
\
inline ctype operator+(ctype x, rtype y) \
{ \
    return {y+x.real, x.imag}; \
} \
\
inline ctype operator+(ctype x, ctype y) \
{ \
    return {x.real+y.real, x.imag+y.imag}; \
} \
\
inline ctype operator-(rtype x, ctype y) \
{ \
    return {x-y.real, -y.imag}; \
} \
\
inline ctype operator-(ctype x, rtype y) \
{ \
    return {x.real-y, x.imag}; \
} \
\
inline ctype operator-(ctype x, ctype y) \
{ \
    return {x.real-y.real, x.imag-y.imag}; \
} \
\
inline ctype operator*(rtype x, ctype y) \
{ \
    return {x*y.real, x*y.imag}; \
} \
\
inline ctype operator*(ctype x, rtype y) \
{ \
    return {y*x.real, y*x.imag}; \
} \
\
inline ctype operator*(ctype x, ctype y) \
{ \
    return {x.real*y.real - x.imag*y.imag, \
            x.real*y.imag + x.imag*y.real}; \
} \
\
inline ctype operator/(rtype x, ctype y) \
{ \
    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
    auto n = std::ilogb(scale); \
    auto yrs = std::scalbn(y.real, -n); \
    auto yis = std::scalbn(y.imag, -n); \
    auto denom = y.real*yrs + y.imag*yis; \
    return {x*yrs/denom, -x*yis/denom}; \
} \
\
inline ctype operator/(ctype x, rtype y) \
{ \
    return {x.real/y, x.imag/y}; \
} \
\
inline ctype operator/(ctype x, ctype y) \
{ \
    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
    auto n = std::ilogb(scale); \
    auto yrs = std::scalbn(y.real, -n); \
    auto yis = std::scalbn(y.imag, -n); \
    auto denom = y.real*yrs + y.imag*yis; \
    return {(x.real*yrs + x.imag*yis)/denom, \
            (x.imag*yrs - x.real*yis)/denom}; \
} \
\
inline ctype& operator+=(ctype& x, rtype y) \
{ \
    x.real += y; \
    return x; \
} \
\
inline ctype& operator+=(ctype& x, ctype y) \
{ \
    x.real += y.real; x.imag += y.imag; \
    return x; \
} \
\
inline ctype& operator-=(ctype& x, rtype y) \
{ \
    x.real -= y; \
    return x; \
} \
\
inline ctype& operator-=(ctype& x, ctype y) \
{ \
    x.real -= y.real; x.imag -= y.imag; \
    return x; \
} \
\
inline ctype& operator*=(ctype& x, rtype y) \
{ \
    x.real *= y; x.imag *= y; \
    return x; \
} \
\
inline ctype& operator*=(ctype& x, ctype y) \
{ \
    x = x * y; \
    return x; \
} \
\
inline ctype& operator/=(ctype& x, rtype y) \
{ \
    x.real /= y; x.imag /= y; \
    return x; \
} \
\
inline ctype& operator/=(ctype& x, ctype y) \
{ \
    x = x / y; \
    return x; \
}

// Generate the operators for both complex precisions.
COMPLEX_MATH_OPS(float, scomplex);
COMPLEX_MATH_OPS(double, dcomplex);
|
||||
|
||||
988
test/tensor_contraction/tcontract_example.cxx
Normal file
988
test/tensor_contraction/tcontract_example.cxx
Normal file
@@ -0,0 +1,988 @@
|
||||
|
||||
#include "tcontract_ref.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
// Scatter offsets are grouped in blocks of BS_K entries; each block may
// carry a single regular stride (see fill_scatter).
static constexpr dim_t BS_K = 8;

/*
 * Describes the tensor dimensions behind the matricized operands: the
 * tensor dims mapped to the matrix m dimension and those mapped to n
 * (their counts, lengths, and strides). The referenced arrays are owned
 * by the caller and must outlive the operation.
 */
struct packm_tensor_params_t
{
    gint_t ndim_m, ndim_n;            // number of tensor dims mapped to m / n
    const dim_t *len_m, *len_n;       // lengths of those dims
    const inc_t *stride_m, *stride_n; // strides of those dims

    packm_tensor_params_t() {}

    packm_tensor_params_t( gint_t ndim_m, const dim_t* len_m, const inc_t* stride_m,
                           gint_t ndim_n, const dim_t* len_n, const inc_t* stride_n )
    : ndim_m(ndim_m), ndim_n(ndim_n),
      len_m(len_m), len_n(len_n),
      stride_m(stride_m), stride_n(stride_n) {}
};

// The gemm kernel consumes the same metadata as the pack kernels.
using gemm_tensor_params_t = packm_tensor_params_t;
|
||||
|
||||
/*
 * Pack kernel for a micropanel whose long (k) dimension is addressed by
 * blocked scatter vectors: every BS_K columns share one entry of `bsa`
 * (a common column stride, or 0 if the block is irregular) and `scata`
 * (per-column offsets). Regular blocks are packed with strided loads;
 * irregular blocks fall back to per-column scattered offsets. The panel
 * is zero-padded out to (panel_dim_max x panel_len_max).
 */
template <typename T>
void packm_ckx_nb
     (
       bool conja,
       dim_t panel_dim,
       dim_t panel_len,
       dim_t panel_dim_max,
       dim_t panel_len_max,
       void* kappa,
       void* a, inc_t inca, inc_t* bsa, inc_t* scata,
       void* p, inc_t ldp
     )
{
    T* restrict a_cast = ( T* )a;
    T* restrict p_cast = ( T* )p;
    auto kappa_cast = *( T* )kappa;

    if ( conja )
    {
        // Advance bsa/scata one block (BS_K entries) per outer iteration.
        for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
        {
            auto lda = *bsa;
            auto panel_len_j = std::min<dim_t>( panel_len-j0, BS_K );

            if ( lda )
            {
                // Regular block: columns are evenly strided by lda
                // starting from the block's first scatter offset.
                T* restrict aj = a_cast + *scata;

                for ( auto j = 0; j < panel_len_j; j++ )
                {
                    for ( auto i = 0; i < panel_dim; i++ )
                        p_cast[ i ] = kappa_cast * conj( aj[ i*inca + j*lda ] );

                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
                        p_cast[ i ] = convert<T>(0.0);

                    p_cast += ldp;
                }
            }
            else
            {
                // Irregular block: look up each column's offset.
                for ( auto j = 0; j < panel_len_j; j++)
                {
                    for ( auto i = 0; i < panel_dim; i++)
                        p_cast[ i ] = kappa_cast * conj( a_cast[ i*inca + scata[j] ] );

                    for ( auto i = panel_dim; i < panel_dim_max; i++)
                        p_cast[ i ] = convert<T>(0.0);

                    p_cast += ldp;
                }
            }
        }
    }
    else
    {
        // Same structure as above, without conjugation.
        for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
        {
            auto lda = *bsa;
            auto panel_len_j = std::min<dim_t>( panel_len-j0, BS_K );

            if ( lda )
            {
                T* restrict aj = a_cast + *scata;

                for ( auto j = 0; j < panel_len_j; j++ )
                {
                    for ( auto i = 0; i < panel_dim; i++ )
                        p_cast[ i ] = kappa_cast * aj[ i*inca + j*lda ];

                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
                        p_cast[ i ] = convert<T>(0.0);

                    p_cast += ldp;
                }
            }
            else
            {
                for ( auto j = 0; j < panel_len_j; j++ )
                {
                    for ( auto i = 0; i < panel_dim; i++ )
                        p_cast[ i ] = kappa_cast * a_cast[ i*inca + scata[j] ];

                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
                        p_cast[ i ] = convert<T>(0.0);

                    p_cast += ldp;
                }
            }
        }
    }

    // Zero-fill any columns beyond the logical panel length.
    for ( auto j = panel_len; j < panel_len_max; j++)
    {
        for ( auto i = 0; i < panel_dim_max; i++)
            p_cast[ i ] = convert<T>(0.0);

        p_cast += ldp;
    }
}
|
||||
|
||||
template <typename T>
|
||||
void packm_ckx_ss
|
||||
(
|
||||
bool conja,
|
||||
dim_t panel_dim,
|
||||
dim_t panel_len,
|
||||
dim_t panel_dim_max,
|
||||
dim_t panel_len_max,
|
||||
void* kappa,
|
||||
void* a, inc_t* inca, inc_t* scata,
|
||||
void* p, inc_t ldp
|
||||
)
|
||||
{
|
||||
T* restrict a_cast = ( T* )a;
|
||||
T* restrict p_cast = ( T* )p;
|
||||
auto kappa_cast = *( T* )kappa;
|
||||
|
||||
if ( conja )
|
||||
{
|
||||
for (dim_t j = 0;j < panel_len;j++)
|
||||
{
|
||||
for (dim_t i = 0;i < panel_dim;i++)
|
||||
p_cast[ i ] = kappa_cast * conj( a_cast[ inca[i] + scata[j] ] );
|
||||
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||
p_cast[ i ] = convert<T>(0.0);
|
||||
|
||||
p_cast += ldp;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (dim_t j = 0;j < panel_len;j++)
|
||||
{
|
||||
for (dim_t i = 0;i < panel_dim;i++)
|
||||
p_cast[ i ] = kappa_cast * a_cast[ inca[i] + scata[j] ];
|
||||
|
||||
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||
p_cast[ i ] = convert<T>(0.0);
|
||||
|
||||
p_cast += ldp;
|
||||
}
|
||||
}
|
||||
|
||||
for (dim_t j = panel_len;j < panel_len_max;j++)
|
||||
{
|
||||
for (dim_t i = 0;i < panel_dim_max;i++)
|
||||
p_cast[ i ] = convert<T>(0.0);
|
||||
|
||||
p_cast += ldp;
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate both pack kernels for each BLIS datatype and build
// datatype-indexed dispatch tables for them.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_ckx_nb<ctype>;

INSERT_GENTFUNC_BASIC0(packm_ckx_nb);

#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_ckx_ss<ctype>;

INSERT_GENTFUNC_BASIC0(packm_ckx_ss);

static decltype(&packm_ckx_nb<void>) GENARRAY( packm_ckx_nb_ukrs, packm_ckx_nb );
static decltype(&packm_ckx_ss<void>) GENARRAY( packm_ckx_ss_ukrs, packm_ckx_ss );
|
||||
|
||||
static void fill_scatter
|
||||
(
|
||||
gint_t ndim,
|
||||
const dim_t* restrict len,
|
||||
const inc_t* restrict stride,
|
||||
dim_t BS,
|
||||
inc_t off,
|
||||
dim_t size,
|
||||
inc_t* restrict scat,
|
||||
inc_t* restrict bs
|
||||
)
|
||||
{
|
||||
if ( size == 0 ) return;
|
||||
|
||||
if ( ndim == 0 )
|
||||
{
|
||||
*scat = 0;
|
||||
*bs = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
if ( ndim == 1 )
|
||||
{
|
||||
auto l = *len;
|
||||
auto s = *stride;
|
||||
for ( auto i = 0; i < l; i++ )
|
||||
{
|
||||
scat[i] = i*s;
|
||||
bs[i] = s;
|
||||
}
|
||||
}
|
||||
|
||||
dim_t tot_len = 1;
|
||||
for ( auto i = 0; i < ndim; i++ )
|
||||
tot_len *= len[i];
|
||||
|
||||
assert(off >= 0);
|
||||
assert(size >= 0);
|
||||
assert(off+size <= tot_len);
|
||||
|
||||
auto len0 = len[0];
|
||||
auto stride0 = stride[0];
|
||||
auto off0 = off % len0;
|
||||
auto off1 = off / len0;
|
||||
auto size1 = ( size + off0 + len0 - 1) / len0;
|
||||
|
||||
inc_t pos1 = 0;
|
||||
inc_t idx = 0;
|
||||
for_each( ndim-1, len+1, off1, size1, pos1, stride+1,
|
||||
[&]
|
||||
{
|
||||
auto pos = pos1 + off0 * stride0;
|
||||
auto len_i = std::min( len0-off0, size-idx );
|
||||
for ( auto i = 0; i < len_i; i++ )
|
||||
{
|
||||
scat[idx++] = pos;
|
||||
pos += stride0;
|
||||
}
|
||||
off0 = 0;
|
||||
});
|
||||
assert(idx == size);
|
||||
|
||||
for ( idx = 0; idx < size; idx += BS )
|
||||
{
|
||||
auto len_i = std::min( BS, size-idx );
|
||||
auto s = stride0;
|
||||
|
||||
for ( auto i = idx; i < idx+len_i-1; i++)
|
||||
{
|
||||
if (scat[i+1]-scat[i] != s)
|
||||
{
|
||||
s = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bs[idx] = s;
|
||||
}
|
||||
}
|
||||
|
||||
// Custom pack-function driver for tensor operands (installed on A/B via
// bli_obj_set_pack_fn). It interprets the packm_tensor_params_t attached to
// the source object as a multi-dimensional (tensor) description of the
// operand, builds row/column scatter and block-scatter vectors in memory
// placed directly after the pack buffer, and packs the operand into the
// micropanel layout expected by the BLIS macrokernel.
//
// a      - source operand; carries packm_tensor_params_t in its pack params.
// p      - packed object, aliased from a and configured here.
// cntx   - context supplying blocksizes and the pack microkernels' datatype.
// rntm   - runtime state, forwarded to the pack-buffer allocator.
// cntl   - control-tree node supplying the pack schema and blocksize ids.
// thread - thrinfo_t node used for work partitioning and the barrier.
void packm_tensor
     (
       obj_t*     a,
       obj_t*     p,
       cntx_t*    cntx,
       rntm_t*    rntm,
       cntl_t*    cntl,
       thrinfo_t* thread
     )
{
    // We begin by copying the fields of A.
    bli_obj_alias_to( a, p );

    // Get information about data types.
    auto dt        = bli_obj_dt( a );
    auto dt_tar    = bli_obj_target_dt( a );
    auto dt_scalar = bli_obj_scalar_dt( a );
    auto dt_size   = bli_dt_size( dt );

    // Mixed-datatype execution is not supported by this custom pack kernel.
    if ( dt_scalar != dt || dt_tar != dt )
        bli_abort();

    // Extract various fields from the control tree.
    auto bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
    auto bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
    auto schema       = bli_cntl_packm_params_pack_schema( cntl );
    auto bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
    auto bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
    auto bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );

    // Only the two micropanel schemas are handled here.
    if ( schema != BLIS_PACKED_ROW_PANELS &&
         schema != BLIS_PACKED_COL_PANELS )
        bli_abort();

    // Store the pack schema to the object.
    bli_obj_set_pack_schema( schema, p );

    // Clear the conjugation field from the object since matrix packing
    // in BLIS is deemed to take care of all conjugation necessary.
    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );

    // If we are packing micropanels, mark P as dense.
    bli_obj_set_uplo( BLIS_DENSE, p );

    // Reset the view offsets to (0,0).
    bli_obj_set_offs( 0, 0, p );

    // Compute the dimensions padded by the dimension multiples. These
    // dimensions will be the dimensions of the packed matrices, including
    // zero-padding, and will be used by the macro- and micro-kernels.
    // We compute them by starting with the effective dimensions of A (now
    // in P) and aligning them to the dimension multiples (typically equal
    // to register blocksizes). This does waste a little bit of space for
    // level-2 operations, but that's okay with us.
    auto m_p     = bli_obj_length( p );
    auto n_p     = bli_obj_width( p );
    auto m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
    auto n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );

    // Save the padded dimensions into the packed object. It is important
    // to save these dimensions since they represent the actual dimensions
    // of the zero-padded matrix.
    bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );

    // The "panel stride" of a micropanel packed object is interpreted as
    // the distance between the (0,0) element of panel k and the (0,0)
    // element of panel k+1. We use the padded width computed above to
    // allow for zero-padding (if necessary/desired) along the far end
    // of each micropanel (ie: the right edge of the matrix). Zero-padding
    // can also occur along the long edge of the last micropanel if the m
    // dimension of the matrix is not a whole multiple of MR.
    auto ps_p = bmult_m_pack * n_p_pad;

    /* Compute the total number of iterations we'll need. */
    auto n_iter = m_p_pad / bmult_m_def;

    // Store the strides and panel dimension in P.
    bli_obj_set_strides( 1, bmult_m_pack, p );
    bli_obj_set_imag_stride( 1, p );
    bli_obj_set_panel_dim( bmult_m_def, p );
    bli_obj_set_panel_stride( ps_p, p );
    bli_obj_set_panel_length( bmult_m_def, p );
    bli_obj_set_panel_width( n_p, p );

    // Compute the size of the packed buffer.
    auto size_p = ps_p * n_iter * dt_size;
    if ( size_p == 0 ) return;

    // Compute the size of the scatter and block-scatter vectors to the total.
    // It is never necessary to add padding for alignment because:
    // 1) ps_p is always even
    // 2) dt_size is a power of two >= 4
    // 3) the alignment of the scatter vectors is at most 8
    auto scat_size = 2 * (m_p + n_p) * sizeof(inc_t);

    // Update the buffer address in p to point to the buffer associated
    // with the mem_t entry acquired from the memory broker (now cached in
    // the control tree node).
    auto p_cast = (char*)bli_packm_alloc( size_p + scat_size, rntm, cntl, thread );
    bli_obj_set_buffer( p_cast, p );

    // Get the addresses of the scatter and block-scatter vectors. These are
    // placed directly after the packed matrix buffer.
    auto rscat = (inc_t*)(p_cast + size_p);
    auto rbs   = rscat + m_p;
    auto cscat = rbs + m_p;
    auto cbs   = cscat + n_p;

    auto a_cast        = (char*)bli_obj_buffer_at_off( a );
    auto panel_dim_off = bli_obj_row_off( a );
    auto panel_len_off = bli_obj_col_off( a );
    auto conja         = bli_obj_conj_status( a );

    // Tensor metadata attached by the caller: number of dimensions, lengths,
    // and strides for the index bundles mapped to the packed m and n dims.
    auto params   = (packm_tensor_params_t*)bli_obj_pack_params( a );
    auto ndim_m   = params->ndim_m;
    auto ndim_n   = params->ndim_n;
    auto len_m    = params->len_m;
    auto len_n    = params->len_n;
    auto stride_m = params->stride_m;
    auto stride_n = params->stride_n;

    obj_t kappa_local;
    auto kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );

    // Pack microkernels for this datatype: 'nb' handles micropanels whose
    // rows are uniformly strided (non-scattered panel dimension), 'ss'
    // handles fully scattered micropanels via the row scatter vector.
    auto packm_nb_ker = packm_ckx_nb_ukrs[ dt ];
    auto packm_ss_ker = packm_ckx_ss_ukrs[ dt ];

    // Rewind the base pointer to the tensor origin; the scatter vectors
    // below are built as absolute offsets that include the view offsets.
    a_cast -= ( panel_dim_off * stride_m[0] +
                panel_len_off * stride_n[0] ) * dt_size;

    /* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */
    if ( bli_thread_am_ochief( thread ) )
    {
        fill_scatter
        (
          ndim_m,
          len_m,
          stride_m,
          bmult_m_def,
          panel_dim_off,
          m_p,
          rscat,
          rbs
        );

        fill_scatter
        (
          ndim_n,
          len_n,
          stride_n,
          BS_K,
          panel_len_off,
          n_p,
          cscat,
          cbs
        );
    }

    /* Wait for the scatter vectors to be done. */
    bli_thread_barrier( thread );

    /* Query the number of threads and thread ids from the current thread's
       packm thrinfo_t node. */
    auto nt  = bli_thread_n_way( thread );
    auto tid = bli_thread_work_id( thread );

    /* Determine the thread range and increment using the current thread's
       packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
       will depend on whether slab or round-robin partitioning was requested
       at configure-time. */
    dim_t it_start, it_end, it_inc;
    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );

    /* Iterate over every logical micropanel in the source matrix. */
    // Note: all iterations are scanned; bli_packm_my_iter() filters the ones
    // assigned to this thread (works for both slab and round-robin).
    for ( auto it = 0; it < n_iter; it += 1 )
    if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
    {
        // Height of this micropanel; the last one may be short.
        auto panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );

        auto p_begin = p_cast + it*ps_p*dt_size;
        // Block stride for this micropanel's rows: nonzero means the rows
        // are uniformly strided, zero means they must be gathered.
        auto inca    = rbs[ it*bmult_m_def ];

        if ( inca )
        {
            // Uniform row stride: start from the panel's first element and
            // let the kernel walk columns through the column scatter/bs.
            auto a_begin = a_cast + rscat[ it*bmult_m_def ]*dt_size;

            packm_nb_ker( conja,
                          panel_dim_i,
                          n_p,
                          bmult_m_def,
                          n_p_pad,
                          kappa_cast,
                          a_begin, inca, cbs, cscat,
                          p_begin, bmult_m_pack );
        }
        else
        {
            // Scattered rows: pass the row scatter vector for this panel.
            auto a_begin   = a_cast;
            auto rscat_use = rscat + it*bmult_m_def;

            packm_ss_ker( conja,
                          panel_dim_i,
                          n_p,
                          bmult_m_def,
                          n_p_pad,
                          kappa_cast,
                          a_begin, rscat_use, cscat,
                          p_begin, bmult_m_pack );
        }
    }
}
|
||||
|
||||
// Generate one dt-specific scatter kernel per BLIS datatype. Each kernel
// writes an m x n block held in a dense temporary x (strides rs_x/cs_x) out
// to y through row/column scatter vectors rs_y/cs_y, whose entries are
// absolute element offsets from the base pointer y. Computes
// y := x + b*y, or y := x when b == 0 (so uninitialized y is never read).
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
void PASTEMAC(ch,op) \
     ( \
       dim_t m, \
       dim_t n, \
       void* x, inc_t rs_x, inc_t cs_x, \
       void* b, \
       void* y, inc_t* rs_y, inc_t* cs_y \
     ) \
{ \
    ctype* restrict x_cast = (ctype*)x; \
    ctype           b_cast = *(ctype*)b; \
    ctype* restrict y_cast = (ctype*)y; \
\
    if ( PASTEMAC(ch,eq0)( b_cast ) ) \
    { \
        for ( auto i = 0; i < m; i++ ) \
        for ( auto j = 0; j < n; j++ ) \
            PASTEMAC(ch,copys)( x_cast[ i*rs_x + j*cs_x ], y_cast[ rs_y[i] + cs_y[j] ] ); \
    } \
    else \
    { \
        for ( auto i = 0; i < m; i++ ) \
        for ( auto j = 0; j < n; j++ ) \
            PASTEMAC(ch,xpbys)( x_cast[ i*rs_x + j*cs_x ], b_cast, y_cast[ rs_y[i] + cs_y[j] ] ); \
    } \
}

INSERT_GENTFUNC_BASIC0(scatter_mxn);

// Function-pointer table of the generated kernels, indexed by num_t. All
// instances share the type-erased signature of bli_sscatter_mxn.
static decltype(&bli_sscatter_mxn) GENARRAY(scatter_mxn, scatter_mxn);
|
||||
|
||||
// Custom gemm macrokernel for tensor contraction (installed on C via
// bli_obj_set_ker_fn). A and B arrive packed into micropanels by
// packm_tensor; C is addressed indirectly through row/column scatter and
// block-scatter vectors built from the gemm_tensor_params_t attached to C.
// Micropanels of C whose block strides are uniform are handed to the
// microkernel directly; otherwise the microkernel computes into a dense
// stack temporary which is then scattered out to C.
void gemm_tensor
     (
       obj_t*     a,
       obj_t*     b,
       obj_t*     c,
       cntx_t*    cntx,
       rntm_t*    rntm,
       cntl_t*    cntl,
       thrinfo_t* thread
     )
{
    auto dt      = bli_obj_dt( c );
    auto dt_size = bli_dt_size( dt );

    auto m = bli_obj_length( c );
    auto n = bli_obj_width( c );
    auto k = bli_obj_width( a );

    auto a_cast = (char*)bli_obj_buffer_at_off( a );
    auto pd_a   = bli_obj_panel_dim( a );
    auto ps_a   = bli_obj_panel_stride( a );

    auto b_cast = (char*)bli_obj_buffer_at_off( b );
    auto pd_b   = bli_obj_panel_dim( b );
    auto ps_b   = bli_obj_panel_stride( b );

    auto c_cast = (char*)bli_obj_buffer_at_off( c );
    auto rs_c0  = bli_obj_row_stride( c );
    auto cs_c0  = bli_obj_col_stride( c );
    auto off_m  = bli_obj_row_off( c );
    auto off_n  = bli_obj_col_off( c );

    // Tensor metadata for C's m- and n-index bundles.
    auto params   = (gemm_tensor_params_t*)bli_obj_ker_params( c );
    auto ndim_m   = params->ndim_m;
    auto ndim_n   = params->ndim_n;
    auto len_m    = params->len_m;
    auto len_n    = params->len_n;
    auto stride_m = params->stride_m;
    auto stride_n = params->stride_n;

    // If C's matrix strides no longer line up with the stored m/n tensor
    // descriptions, the operation was presumably transposed internally;
    // swap the descriptions to match. (NOTE(review): confirm this is the
    // only condition under which the strides can disagree.)
    if ( rs_c0 != stride_m[0] || cs_c0 != stride_n[0] )
    {
        std::swap( ndim_m, ndim_n );
        std::swap( len_m, len_n );
        std::swap( stride_m, stride_n );
    }

    /* If any dimension is zero, return immediately. */
    if ( bli_zero_dim3( m, n, k ) ) return;

    // Rewind the base pointer to the tensor origin; the scatter vectors
    // below are built as absolute offsets that include the view offsets.
    c_cast -= ( off_m * stride_m[0] +
                off_n * stride_n[0] ) * dt_size;

    // Detach and multiply the scalars attached to A and B.
    // NOTE: We know that the internal scalars of A and B are already of the
    // target datatypes because the necessary typecasting would have already
    // taken place during bli_packm_init().
    obj_t scalar_a;
    obj_t scalar_b;
    bli_obj_scalar_detach( a, &scalar_a );
    bli_obj_scalar_detach( b, &scalar_b );
    bli_mulsc( &scalar_a, &scalar_b );

    // Grab the addresses of the internal scalar buffers for the scalar
    // merged above and the scalar attached to C.
    // NOTE: We know that scalar_b is of type dt due to the above code
    // that casts the scalars of A and B to dt via scalar_a and scalar_b,
    // and we know that the internal scalar in C is already of the type dt
    // due to the casting in the implementation of bli_obj_scalar_attach().
    auto alpha_cast = (char*)bli_obj_internal_scalar_buffer( &scalar_b );
    auto beta_cast  = (char*)bli_obj_internal_scalar_buffer( c );

    /* Alias some constants to simpler names. */
    auto MR = pd_a;
    auto NR = pd_b;

    /* Query the context for the micro-kernel address and cast it to its
       function pointer type. */
    // NOTE(review): the kernel is taken from the native table while the
    // storage preference below is queried from the virtual table; confirm
    // these agree for every context this runs under.
    auto gemm_ukr = (gemm_ukr_vft)bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx );

    /* Temporary C buffer for edge cases. Note that the strides of this
       temporary buffer are set so that they match the storage of the
       original C matrix. For example, if C is column-stored, ct will be
       column-stored as well. */
    char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
    auto col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx );
    auto rs_ct    = ( col_pref ? 1 : NR );
    auto cs_ct    = ( col_pref ? MR : 1 );
    auto zero     = (char*)bli_obj_buffer_for_const( dt, &BLIS_ZERO );

    /*
       Assumptions/assertions:
         rs_a == 1
         cs_a == PACKMR
         pd_a == MR
         ps_a == stride to next micro-panel of A
         rs_b == PACKNR
         cs_b == 1
         pd_b == NR
         ps_b == stride to next micro-panel of B
         rs_c == (no assumptions)
         cs_c == (no assumptions)
    */

    // Allocate the scatter and block-scatter vectors for C from the pack
    // buffer pool (general-use checkout; see bli_packm_alloc_ex).
    auto scat_size = 2 * (m + n) * sizeof(inc_t);
    auto rscat_c = (inc_t*)bli_packm_alloc_ex( scat_size, BLIS_BUFFER_FOR_GEN_USE, rntm, cntl, thread );
    auto rbs_c   = rscat_c + m;
    auto cscat_c = rbs_c + m;
    auto cbs_c   = cscat_c + n;

    /* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */
    if ( bli_thread_am_ochief( thread ) )
    {
        fill_scatter
        (
          ndim_m,
          len_m,
          stride_m,
          MR,
          off_m,
          m,
          rscat_c,
          rbs_c
        );

        fill_scatter
        (
          ndim_n,
          len_n,
          stride_n,
          NR,
          off_n,
          n,
          cscat_c,
          cbs_c
        );
    }

    /* Wait for the scatter vectors to be done. */
    bli_thread_barrier( thread );

    /* Compute number of primary and leftover components of the m and n
       dimensions. */
    auto n_iter = n / NR;
    auto n_left = n % NR;

    auto m_iter = m / MR;
    auto m_left = m % MR;

    if ( n_left ) ++n_iter;
    if ( m_left ) ++m_iter;

    /* Determine some increments used to step through A, B, and C. */
    auto rstep_a = ps_a * dt_size;
    auto cstep_b = ps_b * dt_size;

    /* Save the virtual microkernel address and the params. */
    auxinfo_t aux;
    bli_auxinfo_set_ukr( (void*)gemm_ukr, &aux );
    bli_auxinfo_set_params( params, &aux );

    /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
       loop around the microkernel. Here we query the thrinfo_t node for the
       1st (ir) loop around the microkernel. */
    auto caucus = bli_thrinfo_sub_node( thread );

    /* Query the number of threads and thread ids for each loop. */
    auto jr_nt  = bli_thread_n_way( thread );
    auto jr_tid = bli_thread_work_id( thread );
    auto ir_nt  = bli_thread_n_way( caucus );
    auto ir_tid = bli_thread_work_id( caucus );

    /* Determine the thread range and increment for the 2nd and 1st loops.
       NOTE: The definition of bli_thread_range_jrir() will depend on whether
       slab or round-robin partitioning was requested at configure-time. */
    dim_t jr_start, jr_end;
    dim_t ir_start, ir_end;
    dim_t jr_inc, ir_inc;
    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
    bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );

    /* Loop over the n dimension (NR columns at a time). */
    for ( auto j = jr_start; j < jr_end; j += jr_inc )
    {
        auto b1 = b_cast + j * cstep_b;

        auto n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );

        /* Initialize our next panel of B to be the current panel of B. */
        auto b2 = b1;

        /* Loop over the m dimension (MR rows at a time). */
        for ( auto i = ir_start; i < ir_end; i += ir_inc )
        {
            auto a1       = a_cast + i * rstep_a;
            auto rscat_c1 = rscat_c + i * MR;
            auto rbs_c1   = rbs_c + i * MR;
            auto cscat_c1 = cscat_c + j * NR;
            auto cbs_c1   = cbs_c + j * NR;

            auto m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );

            /* Compute the addresses of the next panels of A and B. */
            auto a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
            if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
            {
                a2 = a_cast;
                b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
                if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
                    b2 = b_cast;
            }

            /* Save addresses of next panels of A and B to the auxinfo_t
               object. */
            bli_auxinfo_set_next_a( a2, &aux );
            bli_auxinfo_set_next_b( b2, &aux );

            // Block strides for this tile of C; zero means the tile's rows
            // (or columns) are not uniformly strided and must be scattered.
            auto rs_c = *rbs_c1;
            auto cs_c = *cbs_c1;

            if ( rs_c && cs_c )
            {
                // Regular tile: address it directly through the scatter
                // offsets of its first row/column.
                auto c11 = c_cast + ( *rscat_c1 + *cscat_c1 ) * dt_size;

                /* Invoke the gemm micro-kernel. */
                gemm_ukr
                (
                  m_cur,
                  n_cur,
                  k,
                  alpha_cast,
                  a1,
                  b1,
                  beta_cast,
                  c11, rs_c, cs_c,
                  &aux,
                  cntx
                );
            }
            else
            {
                // Irregular tile: compute a full MR x NR tile into the
                // dense temporary with beta = 0, ...
                /* Invoke the gemm micro-kernel. */
                gemm_ukr
                (
                  MR,
                  NR,
                  k,
                  alpha_cast,
                  a1,
                  b1,
                  zero,
                  &ct, rs_ct, cs_ct,
                  &aux,
                  cntx
                );

                // ... then apply beta while scattering the valid m_cur x
                // n_cur portion out to C.
                /* Scatter to C. */
                scatter_mxn[ dt ]
                (
                  m_cur, n_cur,
                  &ct, rs_ct, cs_ct,
                  beta_cast,
                  c_cast, rscat_c1, cscat_c1
                );
            }
        }
    }
}
|
||||
|
||||
static bool has_unit_stride( const std::vector<inc_t>& stride )
|
||||
{
|
||||
for ( auto s : stride )
|
||||
if ( s == 1 )
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Tensor contraction driver built on BLIS gemm. Validates the stride
// metadata, flattens the m-, n-, and k-index bundles into matrix dimensions,
// attaches the custom pack (packm_tensor) and kernel (gemm_tensor)
// functions plus per-operand tensor metadata to the obj_t's, and invokes
// bli_gemm_ex. The stride vectors are taken by value because they are
// reordered and possibly padded locally.
void tcontract( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                const void* alpha, const void* a, std::vector<inc_t> rs_a, std::vector<inc_t> cs_a,
                const void* b, std::vector<inc_t> rs_b, std::vector<inc_t> cs_b,
                const void* beta, void* c, std::vector<inc_t> rs_c, std::vector<inc_t> cs_c )
{
    // Every index bundle must supply one stride per length, for each
    // operand in which that bundle appears.
    if ( rs_a.size() != m.size() ||
         rs_b.size() != k.size() ||
         rs_c.size() != m.size() )
        bli_check_error_code( BLIS_INVALID_ROW_STRIDE );

    if ( cs_a.size() != k.size() ||
         cs_b.size() != n.size() ||
         cs_c.size() != n.size() )
        bli_check_error_code( BLIS_INVALID_COL_STRIDE );

    // Flattened (matrix) dimensions: products of the bundled lengths.
    dim_t m_mat = 1;
    dim_t n_mat = 1;
    dim_t k_mat = 1;
    for ( auto& i : m ) m_mat *= i;
    for ( auto& i : n ) n_mat *= i;
    for ( auto& i : k ) k_mat *= i;

    // Sort each index bundle by increasing stride of an "anchor" operand,
    // preferring the operand that has a unit-stride dimension so the
    // matrix views below get stride-1 leading dimensions where possible.
    // Bubble sort keeps the paired stride arrays in lock-step; stride_m
    // aliases one of the arrays being swapped, so the sort keys stay
    // consistent as elements move.
    auto& stride_m = has_unit_stride( rs_c ) ? rs_c : rs_a;
    for ( int i = 1; i < m.size(); i++ )
    for ( int j = 0; j < m.size()-i; j++ )
    if ( stride_m[j] > stride_m[j+1] )
    {
        std::swap( rs_a[j], rs_a[j+1] );
        std::swap( rs_c[j], rs_c[j+1] );
    }

    auto& stride_n = has_unit_stride( cs_c ) ? cs_c : cs_b;
    for ( int i = 1; i < n.size(); i++ )
    for ( int j = 0; j < n.size()-i; j++ )
    if ( stride_n[j] > stride_n[j+1] )
    {
        std::swap( cs_b[j], cs_b[j+1] );
        std::swap( cs_c[j], cs_c[j+1] );
    }

    auto& stride_k = has_unit_stride( cs_a ) ? cs_a : rs_b;
    for ( int i = 1; i < k.size(); i++ )
    for ( int j = 0; j < k.size()-i; j++ )
    if ( stride_k[j] > stride_k[j+1] )
    {
        std::swap( cs_a[j], cs_a[j+1] );
        std::swap( rs_b[j], rs_b[j+1] );
    }

    // Guarantee at least one stride per bundle (an empty bundle flattens
    // to a length-1 dimension) so the [0] accesses below are valid.
    if ( rs_a.empty() ) rs_a.push_back( 1 );
    if ( cs_a.empty() ) cs_a.push_back( 1 );
    if ( rs_b.empty() ) rs_b.push_back( 1 );
    if ( cs_b.empty() ) cs_b.push_back( 1 );
    if ( rs_c.empty() ) rs_c.push_back( 1 );
    if ( cs_c.empty() ) cs_c.push_back( 1 );

    // Matrix views over the tensor buffers; only the leading strides are
    // meaningful to BLIS proper, the rest travel via the params below.
    obj_t a_o, b_o, c_o;
    bli_obj_create_with_attached_buffer( dt, m_mat, k_mat, const_cast<void*>(a), rs_a[0], cs_a[0], &a_o );
    bli_obj_create_with_attached_buffer( dt, k_mat, n_mat, const_cast<void*>(b), rs_b[0], cs_b[0], &b_o );
    bli_obj_create_with_attached_buffer( dt, m_mat, n_mat, c                   , rs_c[0], cs_c[0], &c_o );

    // Tensor metadata consumed by packm_tensor (A, B) and gemm_tensor (C).
    // These live on the stack and must outlive the bli_gemm_ex call below.
    packm_tensor_params_t params_a( m.size(), m.data(), rs_a.data(),
                                    k.size(), k.data(), cs_a.data() );
    packm_tensor_params_t params_b( n.size(), n.data(), cs_b.data(),
                                    k.size(), k.data(), rs_b.data() );
    gemm_tensor_params_t  params_c( m.size(), m.data(), rs_c.data(),
                                    n.size(), n.data(), cs_c.data() );

    // Install the custom pack kernels and macrokernel on the operands.
    bli_obj_set_pack_fn( packm_tensor, &a_o );
    bli_obj_set_pack_fn( packm_tensor, &b_o );
    bli_obj_set_ker_fn( gemm_tensor, &c_o );
    bli_obj_set_pack_params( &params_a, &a_o );
    bli_obj_set_pack_params( &params_b, &b_o );
    bli_obj_set_ker_params( &params_c, &c_o );

    obj_t alpha_o, beta_o;
    bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(alpha), &alpha_o );
    bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(beta), &beta_o );

    // The sup (skinny/small) code path bypasses the custom pack/ker
    // functions, so it must be disabled for this runtime instance.
    rntm_t rntm;
    bli_rntm_init_from_global( &rntm );
    bli_rntm_disable_l3_sup( &rntm );

    bli_gemm_ex( &alpha_o, &a_o, &b_o, &beta_o, &c_o, NULL, &rntm );
}
|
||||
|
||||
int main()
|
||||
{
|
||||
auto N = 5;
|
||||
|
||||
gint_t ndim_a = 4;
|
||||
gint_t ndim_b = 4;
|
||||
gint_t ndim_c = 4;
|
||||
|
||||
std::vector<dim_t> len_a(ndim_a, N);
|
||||
std::vector<dim_t> len_b(ndim_b, N);
|
||||
std::vector<dim_t> len_c(ndim_c, N);
|
||||
|
||||
std::vector<inc_t> stride_a(ndim_a, 1);
|
||||
std::vector<inc_t> stride_b(ndim_b, 1);
|
||||
std::vector<inc_t> stride_c(ndim_c, 1);
|
||||
for ( gint_t i = 1; i < ndim_a; i++ )
|
||||
stride_a[i] = stride_a[i-1] * len_a[i - 1];
|
||||
for ( gint_t i = 1; i < ndim_b; i++ )
|
||||
stride_b[i] = stride_b[i-1] * len_b[i - 1];
|
||||
for ( gint_t i = 1; i < ndim_c; i++ )
|
||||
stride_c[i] = stride_c[i-1] * len_c[i - 1];
|
||||
|
||||
std::vector<int> dim_a(ndim_a);
|
||||
std::vector<int> dim_b(ndim_b);
|
||||
std::vector<int> dim_c(ndim_c);
|
||||
std::iota(dim_a.begin(), dim_a.end(), 0);
|
||||
std::iota(dim_b.begin(), dim_b.end(), 0);
|
||||
std::iota(dim_c.begin(), dim_c.end(), 0);
|
||||
|
||||
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||
do
|
||||
do
|
||||
do
|
||||
{
|
||||
auto dt = ( num_t )dt_;
|
||||
|
||||
auto ndim_m = (ndim_a + ndim_c - ndim_b)/2;
|
||||
auto ndim_k = (ndim_a + ndim_b - ndim_c)/2;
|
||||
|
||||
std::vector<dim_t> m(len_a.begin(), len_a.begin()+ndim_m);
|
||||
std::vector<dim_t> n(len_b.begin()+ndim_k, len_b.end());
|
||||
std::vector<dim_t> k(len_b.begin(), len_b.begin()+ndim_k);
|
||||
|
||||
std::vector<inc_t> rs_a(stride_a.begin(), stride_a.begin()+ndim_m);
|
||||
std::vector<inc_t> cs_a(stride_a.begin()+ndim_m, stride_a.end());
|
||||
std::vector<inc_t> rs_b(stride_b.begin(), stride_b.begin()+ndim_k);
|
||||
std::vector<inc_t> cs_b(stride_b.begin()+ndim_k, stride_b.end());
|
||||
std::vector<inc_t> rs_c(stride_c.begin(), stride_c.begin()+ndim_m);
|
||||
std::vector<inc_t> cs_c(stride_c.begin()+ndim_m, stride_c.end());
|
||||
|
||||
dim_t m_tot = 1;
|
||||
dim_t n_tot = 1;
|
||||
dim_t k_tot = 1;
|
||||
for ( auto i : m ) m_tot *= i;
|
||||
for ( auto i : n ) n_tot *= i;
|
||||
for ( auto i : k ) k_tot *= i;
|
||||
|
||||
obj_t a, b, c, c_ref, norm;
|
||||
|
||||
bli_obj_create( dt, m_tot*k_tot, 1, 1, 1, &a );
|
||||
bli_obj_create( dt, k_tot*n_tot, 1, 1, 1, &b );
|
||||
bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c );
|
||||
bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c_ref );
|
||||
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||
|
||||
bli_randv( &a );
|
||||
bli_randv( &b );
|
||||
bli_randv( &c );
|
||||
bli_copyv( &c, &c_ref );
|
||||
|
||||
tcontract( dt, m, n, k,
|
||||
bli_obj_buffer_for_const( dt, &BLIS_ONE ),
|
||||
bli_obj_buffer( &a ), rs_a, cs_a,
|
||||
bli_obj_buffer( &b ), rs_b, cs_b,
|
||||
bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
|
||||
bli_obj_buffer( &c ), rs_c, cs_c );
|
||||
|
||||
tcontract_ref( dt, m, n, k,
|
||||
bli_obj_buffer_for_const( dt, &BLIS_ONE ),
|
||||
bli_obj_buffer( &a ), rs_a, cs_a,
|
||||
bli_obj_buffer( &b ), rs_b, cs_b,
|
||||
bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
|
||||
bli_obj_buffer( &c_ref ), rs_c, cs_c );
|
||||
|
||||
bli_subv( &c_ref, &c );
|
||||
bli_normfv( &c, &norm );
|
||||
|
||||
double normr, normi;
|
||||
bli_getsc( &norm, &normr, &normi );
|
||||
|
||||
printf("dt: %d, dim_a: [%d,%d,%d,%d], dim_b: [%d,%d,%d,%d], dim_c: [%d,%d,%d,%d], norm: %g\n",
|
||||
dt, dim_a[0], dim_a[1], dim_a[2], dim_a[3],
|
||||
dim_b[0], dim_b[1], dim_b[2], dim_b[3],
|
||||
dim_c[0], dim_c[1], dim_c[2], dim_c[3],
|
||||
normr / std::sqrt( bli_obj_vector_dim( &c ) ) );
|
||||
|
||||
bli_obj_free( &a );
|
||||
bli_obj_free( &b );
|
||||
bli_obj_free( &c );
|
||||
bli_obj_free( &c_ref );
|
||||
}
|
||||
while (std::next_permutation(dim_a.begin(), dim_a.end()));
|
||||
while (std::next_permutation(dim_b.begin(), dim_b.end()));
|
||||
while (std::next_permutation(dim_c.begin(), dim_c.end()));
|
||||
}
|
||||
|
||||
// ===== begin new file: test/tensor_contraction/tcontract_ref.cxx (67 lines added) =====
|
||||
#include "tcontract_ref.hpp"
|
||||
|
||||
// Reference tensor contraction for element type T:
//   C := alpha * (A contracted with B over the k-indices) + beta * C.
// Implemented as three nested odometer loops (for_each) that walk raw
// pointers into A, B, and C using the per-dimension stride vectors.
template <typename T>
void tcontract_ref( const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
                    const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
                    const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
{
    auto alpha_cast = *( T* )alpha;
    auto beta_cast  = *( T* )beta;
    auto a_cast     = ( T* )a;
    auto b_cast     = ( T* )b;
    auto c_cast     = ( T* )c;

    // Outer loop over the m-index space: advances a_cast by rs_a and
    // c_cast by rs_c in lock-step with the multi-index.
    for_each(m.size(), m.data(), a_cast, rs_a.data(), c_cast, rs_c.data(),
    [&]
    {
        // Middle loop over the n-index space: advances b_cast and c_cast.
        for_each(n.size(), n.data(), b_cast, cs_b.data(), c_cast, cs_c.data(),
        [&]
        {
            auto ab = convert<T>(0.0);

            // Innermost reduction over the contracted k-index space.
            for_each(k.size(), k.data(), a_cast, cs_a.data(), b_cast, rs_b.data(),
            [&]
            {
                ab += (*a_cast) * (*b_cast);
            });

            // beta == 0 must overwrite rather than accumulate so that an
            // uninitialized C element is never read.
            if ( beta_cast == convert<T>(0.0) )
            {
                *c_cast = alpha_cast * ab;
            }
            else
            {
                *c_cast = alpha_cast * ab + beta_cast * (*c_cast);
            }
        });

        // The odometer must return each cursor to its starting position
        // after a full sweep.
        assert(b_cast == b);
    });

    assert(a_cast == a);
    assert(c_cast == c);
}
|
||||
|
||||
// Instantiate one typed entry per BLIS datatype: each PASTEMAC(ch,op) is a
// pointer to the tcontract_ref<T> template instantiated for that type.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &tcontract_ref<ctype>;

INSERT_GENTFUNC_BASIC0(tcontract_ref);

// Dispatch table indexed by num_t. The signature of tcontract_ref<T> does
// not mention T, so <void> names the common function-pointer type without
// requiring the definition to be instantiated.
static decltype(&tcontract_ref<void>) GENARRAY( tcontract_ref_impl, tcontract_ref );
|
||||
|
||||
// Runtime-datatype entry point: select the typed reference implementation
// for 'dt' from the dispatch table and forward all arguments unchanged.
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
                    const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
                    const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
{
    const auto impl = tcontract_ref_impl[ dt ];
    impl( m, n, k,
          alpha, a, rs_a, cs_a,
                 b, rs_b, cs_b,
          beta,  c, rs_c, cs_c );
}
|
||||
|
||||
// ===== begin new file: test/tensor_contraction/tcontract_ref.hpp (100 lines added) =====
|
||||
#include "blis.h"
|
||||
#include "complex_math.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
|
||||
inline void increment(inc_t, gint_t) {}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args)
|
||||
{
|
||||
off += s[i]*n;
|
||||
increment(n, i, args...);
|
||||
}
|
||||
|
||||
// Core odometer loop shared by all for_each overloads. Visits 'len'
// consecutive logical elements of an ndim-dimensional index space with
// extents n[0..ndim-1] (dimension 0 fastest), starting at linearized
// offset 'off'. For every visited index it calls body(), then advances
// each cursor in 'args' (alternating cursor and stride-array arguments)
// via increment().
template <typename Body, typename... Args>
void for_each_impl(gint_t ndim, const dim_t* n,
                   dim_t off, dim_t len,
                   Body& body,
                   Args&... args)
{
    // Current multi-index; at most 8 dimensions are supported.
    std::array<dim_t,8> i = {};
    assert( ndim <= i.size() );

    // Decompose the starting linear offset into a multi-index, moving each
    // cursor to the corresponding starting position.
    if ( off )
    {
        for ( gint_t k = 0; k < ndim; k++ )
        {
            i[k] = off % n[k];
            off /= n[k];
            increment(i[k], k, args...);
        }
    }

    for ( dim_t pos = 0; pos < len; pos++ )
    {
        body();

        // Odometer-style carry: bump the fastest dimension; on wrap-around,
        // rewind that dimension's cursors to index 0 and carry into the
        // next dimension.
        for ( gint_t k = 0; k < ndim; k++ )
        {
            if ( i[k] == n[k]-1 )
            {
                increment(-i[k], k, args...);
                i[k] = 0;
            }
            else
            {
                increment(1, k, args...);
                i[k]++;
                break;
            }
        }
    }
}
|
||||
|
||||
template <typename T, typename Body>
|
||||
void for_each(gint_t ndim, const dim_t* n,
|
||||
dim_t off, dim_t len,
|
||||
T& a, const inc_t* s_a,
|
||||
Body&& body)
|
||||
{
|
||||
for_each_impl( ndim, n, off, len, body, a, s_a );
|
||||
}
|
||||
|
||||
template <typename T, typename Body>
|
||||
void for_each(gint_t ndim, const dim_t* n,
|
||||
dim_t off, dim_t len,
|
||||
T& a, const inc_t* s_a,
|
||||
T& b, const inc_t* s_b,
|
||||
Body&& body)
|
||||
{
|
||||
for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b );
|
||||
}
|
||||
|
||||
template <typename T, typename Body>
|
||||
void for_each(gint_t ndim, const dim_t* n,
|
||||
T& a, const inc_t* s_a,
|
||||
Body&& body)
|
||||
{
|
||||
dim_t len = 1;
|
||||
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
|
||||
for_each_impl( ndim, n, 0, len, body, a, s_a );
|
||||
}
|
||||
|
||||
template <typename T, typename Body>
|
||||
void for_each(gint_t ndim, const dim_t* n,
|
||||
T& a, const inc_t* s_a,
|
||||
T& b, const inc_t* s_b,
|
||||
Body&& body)
|
||||
{
|
||||
dim_t len = 1;
|
||||
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
|
||||
for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b );
|
||||
}
|
||||
|
||||
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
|
||||
const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
|
||||
const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
|
||||
const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c );
|
||||
// ===== end of captured diff =====