mirror of
https://github.com/amd/blis.git
synced 2026-05-13 02:25:39 +00:00
Move edge cases to gemm ukr; more user-custom mods. (#583)
Details: - Moved edge-case handling into the gemm microkernel. This required changing the microkernel API to take m and n dimension parameters. This required updating all existing gemm microkernel function pointer types, function signatures, and related definitions to take m and n dimensions. We also updated all existing kernels in the 'kernels' directory to take m and n dimensions, and implemented edge-case handling within those microkernels via a collection of new C preprocessor macros defined within bli_edge_case_macro_defs.h. Also removed the assembly code that formerly would handle general stride IO on the microtile, since this can now be handled by the same code that does edge cases. - Pass the obj_t.ker_fn (of matrix C) into bli_gemm_cntl_create() and bli_trsm_cntl_create(), where this function pointer is used in lieu of the default macrokernel when it is non-NULL, and ignored when it is NULL. - Re-implemented macrokernel in bli_gemm_ker_var2.c to be a single function using byte pointers rather that one function for each floating-point datatype. Also, obtain the microkernel function pointer from the .ukr field of the params struct embedded within the obj_t for matrix C (assuming params is non-NULL and contains a non-NULL value in the .ukr field). Communicate both the gemm microkernel pointer to use as well as the params struct to the microkernel via the auxinfo_t struct. - Defined gemm_ker_params_t type (for the aforementioned obj_t.params struct) in bli_gemm_var.h. - Retired the separate _md macrokernel for mixed datatype computation. We now use the reimplemented bli_gemm_ker_var2() instead. - Updated gemmt macrokernels to pass m and n dimensions into microkernel calls. - Removed edge-case handling from trmm and trsm macrokernels. - Moved most of bli_packm_alloc() code into a new helper function, bli_packm_alloc_ex(). - Fixed a typo bug in bli_gemmtrsm_u_template_noopt_mxn.c. 
- Added test/syrk_diagonal and test/tensor_contraction directories with associated code to test those operations.
This commit is contained in:
@@ -37,6 +37,8 @@
|
|||||||
|
|
||||||
void bli_zgemm_template_noopt
|
void bli_zgemm_template_noopt
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a1,
|
dcomplex* restrict a1,
|
||||||
@@ -88,8 +90,7 @@ void bli_zgemm_template_noopt
|
|||||||
|
|
||||||
dim_t l, j, i;
|
dim_t l, j, i;
|
||||||
|
|
||||||
dcomplex ab[ bli_zmr *
|
dcomplex ab[ mr * nr ];
|
||||||
bli_znr ];
|
|
||||||
dcomplex* abij;
|
dcomplex* abij;
|
||||||
dcomplex ai, bj;
|
dcomplex ai, bj;
|
||||||
|
|
||||||
@@ -137,16 +138,16 @@ void bli_zgemm_template_noopt
|
|||||||
if ( bli_zeq0( *beta ) )
|
if ( bli_zeq0( *beta ) )
|
||||||
{
|
{
|
||||||
/* c11 := ab */
|
/* c11 := ab */
|
||||||
bli_zcopys_mxn( mr,
|
bli_zcopys_mxn( m,
|
||||||
nr,
|
n,
|
||||||
ab, rs_ab, cs_ab,
|
ab, rs_ab, cs_ab,
|
||||||
c11, rs_c, cs_c );
|
c11, rs_c, cs_c );
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
/* c11 := beta * c11 + ab */
|
/* c11 := beta * c11 + ab */
|
||||||
bli_zxpbys_mxn( mr,
|
bli_zxpbys_mxn( m,
|
||||||
nr,
|
n,
|
||||||
ab, rs_ab, cs_ab,
|
ab, rs_ab, cs_ab,
|
||||||
beta,
|
beta,
|
||||||
c11, rs_c, cs_c );
|
c11, rs_c, cs_c );
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt
|
|||||||
*/
|
*/
|
||||||
const num_t dt = BLIS_DCOMPLEX;
|
const num_t dt = BLIS_DCOMPLEX;
|
||||||
|
|
||||||
|
const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
|
||||||
|
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
|
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
|
||||||
|
|
||||||
const inc_t rs_b = packnr;
|
const inc_t rs_b = packnr;
|
||||||
@@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt
|
|||||||
/* b11 = alpha * b11 - a10 * b01; */
|
/* b11 = alpha * b11 - a10 * b01; */
|
||||||
bli_zgemm_template_noopt
|
bli_zgemm_template_noopt
|
||||||
(
|
(
|
||||||
|
mr,
|
||||||
|
nr,
|
||||||
k,
|
k,
|
||||||
minus_one,
|
minus_one,
|
||||||
a10,
|
a10,
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt
|
|||||||
*/
|
*/
|
||||||
const num_t dt = BLIS_DCOMPLEX;
|
const num_t dt = BLIS_DCOMPLEX;
|
||||||
|
|
||||||
|
const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
|
||||||
|
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
|
||||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
|
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );
|
||||||
|
|
||||||
const inc_t rs_b = packnr;
|
const inc_t rs_b = packnr;
|
||||||
@@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt
|
|||||||
/* b11 = alpha * b11 - a12 * b21; */
|
/* b11 = alpha * b11 - a12 * b21; */
|
||||||
bli_zgemm_template_noopt
|
bli_zgemm_template_noopt
|
||||||
(
|
(
|
||||||
|
mr,
|
||||||
|
nr,
|
||||||
k,
|
k,
|
||||||
minus_one,
|
minus_one,
|
||||||
a12,
|
a10,
|
||||||
b21,
|
b01,
|
||||||
alpha,
|
alpha,
|
||||||
b11, rs_b, cs_b,
|
b11, rs_b, cs_b,
|
||||||
data
|
data
|
||||||
|
|||||||
@@ -36,16 +36,35 @@
|
|||||||
#include "blis.h"
|
#include "blis.h"
|
||||||
|
|
||||||
void* bli_packm_alloc
|
void* bli_packm_alloc
|
||||||
(
|
(
|
||||||
siz_t size_needed,
|
siz_t size_needed,
|
||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
cntl_t* cntl,
|
cntl_t* cntl,
|
||||||
thrinfo_t* thread
|
thrinfo_t* thread
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Query the pack buffer type from the control tree node.
|
// Query the pack buffer type from the control tree node.
|
||||||
packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl );
|
packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl );
|
||||||
|
|
||||||
|
return bli_packm_alloc_ex
|
||||||
|
(
|
||||||
|
size_needed,
|
||||||
|
pack_buf_type,
|
||||||
|
rntm,
|
||||||
|
cntl,
|
||||||
|
thread
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void* bli_packm_alloc_ex
|
||||||
|
(
|
||||||
|
siz_t size_needed,
|
||||||
|
packbuf_t pack_buf_type,
|
||||||
|
rntm_t* rntm,
|
||||||
|
cntl_t* cntl,
|
||||||
|
thrinfo_t* thread
|
||||||
|
)
|
||||||
|
{
|
||||||
// Query the address of the mem_t entry within the control tree node.
|
// Query the address of the mem_t entry within the control tree node.
|
||||||
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl );
|
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl );
|
||||||
|
|
||||||
@@ -55,7 +74,7 @@ void* bli_packm_alloc
|
|||||||
siz_t cntl_mem_size = 0;
|
siz_t cntl_mem_size = 0;
|
||||||
|
|
||||||
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
||||||
cntl_mem_size = bli_mem_size( cntl_mem_p );
|
cntl_mem_size = bli_mem_size( cntl_mem_p );
|
||||||
|
|
||||||
if ( cntl_mem_size < size_needed )
|
if ( cntl_mem_size < size_needed )
|
||||||
{
|
{
|
||||||
@@ -64,14 +83,15 @@ void* bli_packm_alloc
|
|||||||
// The chief thread releases the existing block associated with
|
// The chief thread releases the existing block associated with
|
||||||
// the mem_t entry in the control tree, and then re-acquires a
|
// the mem_t entry in the control tree, and then re-acquires a
|
||||||
// new block, saving the associated mem_t entry to local_mem_s.
|
// new block, saving the associated mem_t entry to local_mem_s.
|
||||||
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
if ( bli_mem_is_alloc( cntl_mem_p ) )
|
||||||
{
|
{
|
||||||
bli_pba_release
|
bli_pba_release
|
||||||
(
|
(
|
||||||
rntm,
|
rntm,
|
||||||
cntl_mem_p
|
cntl_mem_p
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
bli_pba_acquire_m
|
bli_pba_acquire_m
|
||||||
(
|
(
|
||||||
rntm,
|
rntm,
|
||||||
@@ -89,11 +109,11 @@ void* bli_packm_alloc
|
|||||||
// this thread's control tree node.
|
// this thread's control tree node.
|
||||||
*cntl_mem_p = *local_mem_p;
|
*cntl_mem_p = *local_mem_p;
|
||||||
|
|
||||||
// Barrier so that the master thread doesn't return from the function
|
// Barrier so that the master thread doesn't return from the function
|
||||||
// before we are done reading.
|
// before we are done reading.
|
||||||
bli_thread_barrier( thread );
|
bli_thread_barrier( thread );
|
||||||
}
|
}
|
||||||
|
|
||||||
return bli_mem_buffer( cntl_mem_p );
|
return bli_mem_buffer( cntl_mem_p );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,11 +32,20 @@
|
|||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
BLIS_EXPORT_BLIS void* bli_packm_alloc
|
BLIS_EXPORT_BLIS void* bli_packm_alloc
|
||||||
(
|
(
|
||||||
siz_t size_needed,
|
siz_t size_needed,
|
||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
cntl_t* cntl,
|
cntl_t* cntl,
|
||||||
thrinfo_t* thread
|
thrinfo_t* thread
|
||||||
);
|
);
|
||||||
|
|
||||||
|
BLIS_EXPORT_BLIS void* bli_packm_alloc_ex
|
||||||
|
(
|
||||||
|
siz_t size_needed,
|
||||||
|
packbuf_t pack_buf_type,
|
||||||
|
rntm_t* rntm,
|
||||||
|
cntl_t* cntl,
|
||||||
|
thrinfo_t* thread
|
||||||
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,14 @@ void bli_l3_cntl_create_if
|
|||||||
family == BLIS_GEMMT ||
|
family == BLIS_GEMMT ||
|
||||||
family == BLIS_TRMM )
|
family == BLIS_TRMM )
|
||||||
{
|
{
|
||||||
*cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b );
|
*cntl_use = bli_gemm_cntl_create
|
||||||
|
(
|
||||||
|
rntm,
|
||||||
|
family,
|
||||||
|
schema_a,
|
||||||
|
schema_b,
|
||||||
|
bli_obj_ker_fn( c )
|
||||||
|
);
|
||||||
}
|
}
|
||||||
else // if ( family == BLIS_TRSM )
|
else // if ( family == BLIS_TRSM )
|
||||||
{
|
{
|
||||||
@@ -66,7 +73,14 @@ void bli_l3_cntl_create_if
|
|||||||
if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT;
|
if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT;
|
||||||
else side = BLIS_RIGHT;
|
else side = BLIS_RIGHT;
|
||||||
|
|
||||||
*cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b );
|
*cntl_use = bli_trsm_cntl_create
|
||||||
|
(
|
||||||
|
rntm,
|
||||||
|
side,
|
||||||
|
schema_a,
|
||||||
|
schema_b,
|
||||||
|
bli_obj_ker_fn( c )
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -47,6 +47,8 @@
|
|||||||
\
|
\
|
||||||
typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \
|
typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ void PASTEMAC0(opname) \
|
|||||||
\
|
\
|
||||||
num_t dt = bli_obj_dt( c ); \
|
num_t dt = bli_obj_dt( c ); \
|
||||||
\
|
\
|
||||||
|
dim_t m = bli_obj_length( c ); \
|
||||||
|
dim_t n = bli_obj_width( c ); \
|
||||||
dim_t k = bli_obj_width( a ); \
|
dim_t k = bli_obj_width( a ); \
|
||||||
void* buf_a = bli_obj_buffer_at_off( a ); \
|
void* buf_a = bli_obj_buffer_at_off( a ); \
|
||||||
void* buf_b = bli_obj_buffer_at_off( b ); \
|
void* buf_b = bli_obj_buffer_at_off( b ); \
|
||||||
@@ -75,6 +77,8 @@ void PASTEMAC0(opname) \
|
|||||||
\
|
\
|
||||||
f \
|
f \
|
||||||
( \
|
( \
|
||||||
|
m, \
|
||||||
|
n, \
|
||||||
k, \
|
k, \
|
||||||
buf_alpha, \
|
buf_alpha, \
|
||||||
buf_a, \
|
buf_a, \
|
||||||
|
|||||||
@@ -42,6 +42,8 @@
|
|||||||
\
|
\
|
||||||
void PASTEMAC(ch,opname) \
|
void PASTEMAC(ch,opname) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype_out* restrict alpha, \
|
ctype_out* restrict alpha, \
|
||||||
ctype_in* restrict a, \
|
ctype_in* restrict a, \
|
||||||
|
|||||||
@@ -39,6 +39,8 @@
|
|||||||
\
|
\
|
||||||
void PASTEMAC(ch,opname) \
|
void PASTEMAC(ch,opname) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
@@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \
|
|||||||
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
||||||
\
|
\
|
||||||
/* Invoke the typed function for the given datatype. */ \
|
/* Invoke the typed function for the given datatype. */ \
|
||||||
f( \
|
f \
|
||||||
k, \
|
( \
|
||||||
alpha, \
|
m, \
|
||||||
a, \
|
n, \
|
||||||
b, \
|
k, \
|
||||||
beta, \
|
alpha, \
|
||||||
c, rs_c, cs_c, \
|
a, \
|
||||||
data, \
|
b, \
|
||||||
cntx \
|
beta, \
|
||||||
); \
|
c, rs_c, cs_c, \
|
||||||
|
data, \
|
||||||
|
cntx \
|
||||||
|
); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR )
|
INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR )
|
||||||
@@ -98,17 +103,18 @@ void PASTEMAC(ch,opname) \
|
|||||||
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
||||||
\
|
\
|
||||||
/* Invoke the typed function for the given datatype. */ \
|
/* Invoke the typed function for the given datatype. */ \
|
||||||
f( \
|
f \
|
||||||
k, \
|
( \
|
||||||
alpha, \
|
k, \
|
||||||
a1x, \
|
alpha, \
|
||||||
a11, \
|
a1x, \
|
||||||
bx1, \
|
a11, \
|
||||||
b11, \
|
bx1, \
|
||||||
c11, rs_c, cs_c, \
|
b11, \
|
||||||
data, \
|
c11, rs_c, cs_c, \
|
||||||
cntx \
|
data, \
|
||||||
); \
|
cntx \
|
||||||
|
); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR )
|
INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR )
|
||||||
@@ -136,13 +142,14 @@ void PASTEMAC(ch,opname) \
|
|||||||
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
|
||||||
\
|
\
|
||||||
/* Invoke the typed function for the given datatype. */ \
|
/* Invoke the typed function for the given datatype. */ \
|
||||||
f( \
|
f \
|
||||||
a, \
|
( \
|
||||||
b, \
|
a, \
|
||||||
c, rs_c, cs_c, \
|
b, \
|
||||||
data, \
|
c, rs_c, cs_c, \
|
||||||
cntx \
|
data, \
|
||||||
); \
|
cntx \
|
||||||
|
); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR )
|
INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR )
|
||||||
|
|||||||
@@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create
|
|||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
opid_t family,
|
opid_t family,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b );
|
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker );
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
@@ -53,18 +54,22 @@ cntl_t* bli_gemmbp_cntl_create
|
|||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
opid_t family,
|
opid_t family,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
void_fp macro_kernel_fp;
|
void_fp macro_kernel_fp;
|
||||||
|
|
||||||
// Use the function pointers to the macrokernels that use slab
|
// Choose the default macrokernel based on the operation family...
|
||||||
// assignment of micropanels to threads in the jr and ir loops.
|
|
||||||
if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2;
|
if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2;
|
||||||
else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
|
else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
|
||||||
else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2;
|
else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2;
|
||||||
else /* should never execute */ macro_kernel_fp = NULL;
|
else /* should never execute */ macro_kernel_fp = NULL;
|
||||||
|
|
||||||
|
// ...unless a non-NULL kernel function pointer is passed in, in which
|
||||||
|
// case we use that instead.
|
||||||
|
if ( ker ) macro_kernel_fp = ker;
|
||||||
|
|
||||||
// Create two nodes for the macro-kernel.
|
// Create two nodes for the macro-kernel.
|
||||||
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node
|
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ cntl_t* bli_gemm_cntl_create
|
|||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
opid_t family,
|
opid_t family,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
);
|
);
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
@@ -48,7 +49,8 @@ cntl_t* bli_gemmbp_cntl_create
|
|||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
opid_t family,
|
opid_t family,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
);
|
);
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
|
|||||||
@@ -283,90 +283,3 @@ void bli_gemm_front
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
if ( bli_obj_dt( a ) != bli_obj_dt( b ) ||
|
|
||||||
bli_obj_dt( a ) != bli_obj_dt( c ) ||
|
|
||||||
bli_obj_comp_prec( c ) != bli_obj_prec( c ) )
|
|
||||||
{
|
|
||||||
const bool a_is_real = bli_obj_is_real( a );
|
|
||||||
const bool a_is_comp = bli_obj_is_complex( a );
|
|
||||||
const bool b_is_real = bli_obj_is_real( b );
|
|
||||||
const bool b_is_comp = bli_obj_is_complex( b );
|
|
||||||
const bool c_is_real = bli_obj_is_real( c );
|
|
||||||
const bool c_is_comp = bli_obj_is_complex( c );
|
|
||||||
|
|
||||||
const bool a_is_single = bli_obj_is_single_prec( a );
|
|
||||||
const bool a_is_double = bli_obj_is_double_prec( a );
|
|
||||||
const bool b_is_single = bli_obj_is_single_prec( b );
|
|
||||||
const bool b_is_double = bli_obj_is_double_prec( b );
|
|
||||||
const bool c_is_single = bli_obj_is_single_prec( c );
|
|
||||||
const bool c_is_double = bli_obj_is_double_prec( c );
|
|
||||||
|
|
||||||
const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC;
|
|
||||||
const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC;
|
|
||||||
|
|
||||||
const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) ||
|
|
||||||
bli_obj_domain( c ) != bli_obj_domain( b );
|
|
||||||
|
|
||||||
( void )a_is_real; ( void )a_is_comp;
|
|
||||||
( void )b_is_real; ( void )b_is_comp;
|
|
||||||
( void )c_is_real; ( void )c_is_comp;
|
|
||||||
( void )a_is_single; ( void )a_is_double;
|
|
||||||
( void )b_is_single; ( void )b_is_double;
|
|
||||||
( void )c_is_single; ( void )c_is_double;
|
|
||||||
( void )comp_single; ( void )comp_double;
|
|
||||||
|
|
||||||
if (
|
|
||||||
//( c_is_comp && a_is_comp && b_is_real ) ||
|
|
||||||
//( c_is_comp && a_is_real && b_is_comp ) ||
|
|
||||||
//( c_is_real && a_is_comp && b_is_comp ) ||
|
|
||||||
//( c_is_comp && a_is_real && b_is_real ) ||
|
|
||||||
//( c_is_real && a_is_comp && b_is_real ) ||
|
|
||||||
//( c_is_real && a_is_real && b_is_comp ) ||
|
|
||||||
//FALSE
|
|
||||||
TRUE
|
|
||||||
)
|
|
||||||
{
|
|
||||||
if (
|
|
||||||
( c_is_single && a_is_single && b_is_single && mixeddomain ) ||
|
|
||||||
( c_is_single && a_is_single && b_is_single && comp_single ) ||
|
|
||||||
( c_is_single && a_is_single && b_is_single && comp_double ) ||
|
|
||||||
( c_is_single && a_is_single && b_is_double ) ||
|
|
||||||
( c_is_single && a_is_double && b_is_single ) ||
|
|
||||||
( c_is_double && a_is_single && b_is_single ) ||
|
|
||||||
( c_is_single && a_is_double && b_is_double ) ||
|
|
||||||
( c_is_double && a_is_single && b_is_double ) ||
|
|
||||||
( c_is_double && a_is_double && b_is_single ) ||
|
|
||||||
( c_is_double && a_is_double && b_is_double && comp_single ) ||
|
|
||||||
( c_is_double && a_is_double && b_is_double && comp_double ) ||
|
|
||||||
( c_is_double && a_is_double && b_is_double && mixeddomain ) ||
|
|
||||||
FALSE
|
|
||||||
)
|
|
||||||
bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl );
|
|
||||||
else
|
|
||||||
bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl );
|
|
||||||
}
|
|
||||||
else
|
|
||||||
bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl );
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#if 0
|
|
||||||
// If any of the storage datatypes differ, or if the execution precision
|
|
||||||
// differs from the storage precision of C, utilize the mixed datatype
|
|
||||||
// code path.
|
|
||||||
// NOTE: We could check the exec dt against the storage dt of C, but for
|
|
||||||
// now we don't support the caller setting the execution domain
|
|
||||||
// explicitly.
|
|
||||||
if ( bli_obj_dt( a ) != bli_obj_dt( b ) ||
|
|
||||||
bli_obj_dt( a ) != bli_obj_dt( c ) ||
|
|
||||||
bli_obj_comp_prec( c ) != bli_obj_prec( c ) )
|
|
||||||
{
|
|
||||||
bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl );
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
|||||||
@@ -35,28 +35,44 @@
|
|||||||
|
|
||||||
#include "blis.h"
|
#include "blis.h"
|
||||||
|
|
||||||
#define FUNCPTR_T gemm_fp
|
typedef void (*xpbys_mxn_vft)
|
||||||
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
void* x, inc_t rs_x, inc_t cs_x,
|
||||||
|
void* b,
|
||||||
|
void* y, inc_t rs_y, inc_t cs_y
|
||||||
|
);
|
||||||
|
|
||||||
typedef void (*FUNCPTR_T)
|
#undef GENTFUNC2
|
||||||
(
|
#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \
|
||||||
pack_t schema_a,
|
\
|
||||||
pack_t schema_b,
|
void PASTEMAC2(chx,chy,op) \
|
||||||
dim_t m,
|
( \
|
||||||
dim_t n,
|
dim_t m, \
|
||||||
dim_t k,
|
dim_t n, \
|
||||||
void* alpha,
|
void* x, inc_t rs_x, inc_t cs_x, \
|
||||||
void* a, inc_t cs_a, inc_t is_a,
|
void* b, \
|
||||||
dim_t pd_a, inc_t ps_a,
|
void* y, inc_t rs_y, inc_t cs_y \
|
||||||
void* b, inc_t rs_b, inc_t is_b,
|
) \
|
||||||
dim_t pd_b, inc_t ps_b,
|
{ \
|
||||||
void* beta,
|
ctypex* restrict x_cast = x; \
|
||||||
void* c, inc_t rs_c, inc_t cs_c,
|
ctypey* restrict b_cast = b; \
|
||||||
cntx_t* cntx,
|
ctypey* restrict y_cast = y; \
|
||||||
rntm_t* rntm,
|
\
|
||||||
thrinfo_t* thread
|
PASTEMAC3(chx,chy,chy,xpbys_mxn) \
|
||||||
);
|
( \
|
||||||
|
m, n, \
|
||||||
|
x_cast, rs_x, cs_x, \
|
||||||
|
b_cast, \
|
||||||
|
y_cast, rs_y, cs_y \
|
||||||
|
); \
|
||||||
|
}
|
||||||
|
|
||||||
static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2);
|
INSERT_GENTFUNC2_BASIC0(xbpys_mxn_fn);
|
||||||
|
INSERT_GENTFUNC2_MIXDP0(xbpys_mxn_fn);
|
||||||
|
|
||||||
|
static xpbys_mxn_vft GENARRAY2_ALL(xbpys_mxn, xbpys_mxn_fn);
|
||||||
|
|
||||||
|
|
||||||
void bli_gemm_ker_var2
|
void bli_gemm_ker_var2
|
||||||
@@ -70,23 +86,8 @@ void bli_gemm_ker_var2
|
|||||||
thrinfo_t* thread
|
thrinfo_t* thread
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
#ifdef BLIS_ENABLE_GEMM_MD
|
|
||||||
// By now, A and B have been packed and cast to the execution precision.
|
|
||||||
// In most cases, such as when storage precision of C differs from the
|
|
||||||
// execution precision, we utilize the mixed datatype code path. However,
|
|
||||||
// a few cases still fall within this kernel, such as mixed domain with
|
|
||||||
// equal precision (ccr, crc, rcc), hence those expressions being disabled
|
|
||||||
// in the conditional below.
|
|
||||||
if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) ||
|
|
||||||
//( bli_obj_domain( c ) != bli_obj_domain( b ) ) ||
|
|
||||||
( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) )
|
|
||||||
{
|
|
||||||
bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread );
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
num_t dt_exec = bli_obj_exec_dt( c );
|
num_t dt_exec = bli_obj_exec_dt( c );
|
||||||
|
num_t dt_c = bli_obj_dt( c );
|
||||||
|
|
||||||
pack_t schema_a = bli_obj_pack_schema( a );
|
pack_t schema_a = bli_obj_pack_schema( a );
|
||||||
pack_t schema_b = bli_obj_pack_schema( b );
|
pack_t schema_b = bli_obj_pack_schema( b );
|
||||||
@@ -95,50 +96,55 @@ void bli_gemm_ker_var2
|
|||||||
dim_t n = bli_obj_width( c );
|
dim_t n = bli_obj_width( c );
|
||||||
dim_t k = bli_obj_width( a );
|
dim_t k = bli_obj_width( a );
|
||||||
|
|
||||||
void* buf_a = bli_obj_buffer_at_off( a );
|
char* a_cast = bli_obj_buffer_at_off( a );
|
||||||
inc_t cs_a = bli_obj_col_stride( a );
|
|
||||||
inc_t is_a = bli_obj_imag_stride( a );
|
inc_t is_a = bli_obj_imag_stride( a );
|
||||||
dim_t pd_a = bli_obj_panel_dim( a );
|
dim_t pd_a = bli_obj_panel_dim( a );
|
||||||
inc_t ps_a = bli_obj_panel_stride( a );
|
inc_t ps_a = bli_obj_panel_stride( a );
|
||||||
|
|
||||||
void* buf_b = bli_obj_buffer_at_off( b );
|
char* b_cast = bli_obj_buffer_at_off( b );
|
||||||
inc_t rs_b = bli_obj_row_stride( b );
|
|
||||||
inc_t is_b = bli_obj_imag_stride( b );
|
inc_t is_b = bli_obj_imag_stride( b );
|
||||||
dim_t pd_b = bli_obj_panel_dim( b );
|
dim_t pd_b = bli_obj_panel_dim( b );
|
||||||
inc_t ps_b = bli_obj_panel_stride( b );
|
inc_t ps_b = bli_obj_panel_stride( b );
|
||||||
|
|
||||||
void* buf_c = bli_obj_buffer_at_off( c );
|
char* c_cast = bli_obj_buffer_at_off( c );
|
||||||
inc_t rs_c = bli_obj_row_stride( c );
|
inc_t rs_c = bli_obj_row_stride( c );
|
||||||
inc_t cs_c = bli_obj_col_stride( c );
|
inc_t cs_c = bli_obj_col_stride( c );
|
||||||
|
|
||||||
obj_t scalar_a;
|
// If any dimension is zero, return immediately.
|
||||||
obj_t scalar_b;
|
if ( bli_zero_dim3( m, n, k ) ) return;
|
||||||
|
|
||||||
void* buf_alpha;
|
|
||||||
void* buf_beta;
|
|
||||||
|
|
||||||
FUNCPTR_T f;
|
|
||||||
|
|
||||||
// Detach and multiply the scalars attached to A and B.
|
// Detach and multiply the scalars attached to A and B.
|
||||||
|
// NOTE: We know that the internal scalars of A and B are already of the
|
||||||
|
// target datatypes because the necessary typecasting would have already
|
||||||
|
// taken place during bli_packm_init().
|
||||||
|
obj_t scalar_a;
|
||||||
|
obj_t scalar_b;
|
||||||
bli_obj_scalar_detach( a, &scalar_a );
|
bli_obj_scalar_detach( a, &scalar_a );
|
||||||
bli_obj_scalar_detach( b, &scalar_b );
|
bli_obj_scalar_detach( b, &scalar_b );
|
||||||
bli_mulsc( &scalar_a, &scalar_b );
|
bli_mulsc( &scalar_a, &scalar_b );
|
||||||
|
|
||||||
// Grab the addresses of the internal scalar buffers for the scalar
|
// Grab the addresses of the internal scalar buffers for the scalar
|
||||||
// merged above and the scalar attached to C.
|
// merged above and the scalar attached to C.
|
||||||
buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b );
|
// NOTE: We know that scalar_b is of type dt_exec due to the above code
|
||||||
buf_beta = bli_obj_internal_scalar_buffer( c );
|
// that casts the scalars of A and B to dt_exec via scalar_a and scalar_b,
|
||||||
|
// and we know that the internal scalar in C is already of the type dt_c
|
||||||
|
// due to the casting in the implementation of bli_obj_scalar_attach().
|
||||||
|
char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b );
|
||||||
|
char* beta_cast = bli_obj_internal_scalar_buffer( c );
|
||||||
|
|
||||||
// If 1m is being employed on a column- or row-stored matrix with a
|
// If 1m is being employed on a column- or row-stored matrix with a
|
||||||
// real-valued beta, we can use the real domain macro-kernel, which
|
// real-valued beta, we can use the real domain macro-kernel, which
|
||||||
// eliminates a little overhead associated with the 1m virtual
|
// eliminates a little overhead associated with the 1m virtual
|
||||||
// micro-kernel.
|
// micro-kernel.
|
||||||
|
// Only employ this optimization if the storage datatype of C is
|
||||||
|
// equal to the execution/computation datatype.
|
||||||
#if 1
|
#if 1
|
||||||
if ( bli_cntx_method( cntx ) == BLIS_1M )
|
if ( bli_cntx_method( cntx ) == BLIS_1M )
|
||||||
{
|
{
|
||||||
bli_gemm_ind_recast_1m_params
|
bli_gemm_ind_recast_1m_params
|
||||||
(
|
(
|
||||||
&dt_exec,
|
&dt_exec,
|
||||||
|
&dt_c,
|
||||||
schema_a,
|
schema_a,
|
||||||
c,
|
c,
|
||||||
&m, &n, &k,
|
&m, &n, &k,
|
||||||
@@ -151,273 +157,211 @@ void bli_gemm_ker_var2
|
|||||||
|
|
||||||
#ifdef BLIS_ENABLE_GEMM_MD
|
#ifdef BLIS_ENABLE_GEMM_MD
|
||||||
// Tweak parameters in select mixed domain cases (rcc, crc, ccr).
|
// Tweak parameters in select mixed domain cases (rcc, crc, ccr).
|
||||||
bli_gemm_md_ker_var2_recast
|
if ( bli_cntx_method( cntx ) == BLIS_NAT )
|
||||||
(
|
{
|
||||||
&dt_exec,
|
bli_gemm_md_ker_var2_recast
|
||||||
bli_obj_dt( a ),
|
(
|
||||||
bli_obj_dt( b ),
|
&dt_exec,
|
||||||
bli_obj_dt( c ),
|
bli_obj_dt( a ),
|
||||||
&m, &n, &k,
|
bli_obj_dt( b ),
|
||||||
&pd_a, &ps_a,
|
&dt_c,
|
||||||
&pd_b, &ps_b,
|
&m, &n, &k,
|
||||||
c,
|
&pd_a, &ps_a,
|
||||||
&rs_c, &cs_c
|
&pd_b, &ps_b,
|
||||||
);
|
c,
|
||||||
|
&rs_c, &cs_c
|
||||||
|
);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Index into the type combination array to extract the correct
|
siz_t dt_size = bli_dt_size( dt_exec );
|
||||||
// function pointer.
|
siz_t dt_c_size = bli_dt_size( dt_c );
|
||||||
f = ftypes[dt_exec];
|
|
||||||
|
|
||||||
// Invoke the function.
|
// Alias some constants to simpler names.
|
||||||
f( schema_a,
|
const dim_t MR = pd_a;
|
||||||
schema_b,
|
const dim_t NR = pd_b;
|
||||||
m,
|
//const dim_t PACKMR = cs_a;
|
||||||
n,
|
//const dim_t PACKNR = rs_b;
|
||||||
k,
|
|
||||||
buf_alpha,
|
|
||||||
buf_a, cs_a, is_a,
|
|
||||||
pd_a, ps_a,
|
|
||||||
buf_b, rs_b, is_b,
|
|
||||||
pd_b, ps_b,
|
|
||||||
buf_beta,
|
|
||||||
buf_c, rs_c, cs_c,
|
|
||||||
cntx,
|
|
||||||
rntm,
|
|
||||||
thread );
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Query the context for the micro-kernel address and cast it to its
|
||||||
|
// function pointer type.
|
||||||
|
gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx );
|
||||||
|
|
||||||
|
// Query the params field from the obj_t. If it is non-NULL, grab the ukr
|
||||||
|
// field of the params struct. If that function pointer is non-NULL, use it
|
||||||
|
// as our microkernel instead of the default microkernel queried from the
|
||||||
|
// cntx above.
|
||||||
|
gemm_ker_params_t* params = bli_obj_ker_params( c );
|
||||||
|
gemm_ukr_vft user_ukr = params ? params->ukr : NULL;
|
||||||
|
if ( user_ukr ) gemm_ukr = user_ukr;
|
||||||
|
|
||||||
|
// Temporary C buffer for edge cases. Note that the strides of this
|
||||||
|
// temporary buffer are set so that they match the storage of the
|
||||||
|
// original C matrix. For example, if C is column-stored, ct will be
|
||||||
|
// column-stored as well.
|
||||||
|
char ct[ BLIS_STACK_BUF_MAX_SIZE ]
|
||||||
|
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
|
||||||
|
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx );
|
||||||
|
const inc_t rs_ct = ( col_pref ? 1 : NR );
|
||||||
|
const inc_t cs_ct = ( col_pref ? MR : 1 );
|
||||||
|
char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO );
|
||||||
|
|
||||||
|
//
|
||||||
|
// Assumptions/assertions:
|
||||||
|
// rs_a == 1
|
||||||
|
// cs_a == PACKMR
|
||||||
|
// pd_a == MR
|
||||||
|
// ps_a == stride to next micro-panel of A
|
||||||
|
// rs_b == PACKNR
|
||||||
|
// cs_b == 1
|
||||||
|
// pd_b == NR
|
||||||
|
// ps_b == stride to next micro-panel of B
|
||||||
|
// rs_c == (no assumptions)
|
||||||
|
// cs_c == (no assumptions)
|
||||||
|
//
|
||||||
|
|
||||||
|
// Compute number of primary and leftover components of the m and n
|
||||||
|
// dimensions.
|
||||||
|
dim_t n_iter = n / NR;
|
||||||
|
dim_t n_left = n % NR;
|
||||||
|
|
||||||
|
dim_t m_iter = m / MR;
|
||||||
|
dim_t m_left = m % MR;
|
||||||
|
|
||||||
|
if ( n_left ) ++n_iter;
|
||||||
|
if ( m_left ) ++m_iter;
|
||||||
|
|
||||||
|
// Determine some increments used to step through A, B, and C.
|
||||||
|
inc_t rstep_a = ps_a * dt_size;
|
||||||
|
|
||||||
|
inc_t cstep_b = ps_b * dt_size;
|
||||||
|
|
||||||
|
inc_t rstep_c = rs_c * MR * dt_c_size;
|
||||||
|
inc_t cstep_c = cs_c * NR * dt_c_size;
|
||||||
|
|
||||||
|
auxinfo_t aux;
|
||||||
|
|
||||||
|
// Save the pack schemas of A and B to the auxinfo_t object.
|
||||||
|
bli_auxinfo_set_schema_a( schema_a, &aux );
|
||||||
|
bli_auxinfo_set_schema_b( schema_b, &aux );
|
||||||
|
|
||||||
|
// Save the imaginary stride of A and B to the auxinfo_t object.
|
||||||
|
bli_auxinfo_set_is_a( is_a, &aux );
|
||||||
|
bli_auxinfo_set_is_b( is_b, &aux );
|
||||||
|
|
||||||
|
// Save the virtual microkernel address and the params.
|
||||||
|
bli_auxinfo_set_ukr( gemm_ukr, &aux );
|
||||||
|
bli_auxinfo_set_params( params, &aux );
|
||||||
|
|
||||||
|
// The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||||
|
// loop around the microkernel. Here we query the thrinfo_t node for the
|
||||||
|
// 1st (ir) loop around the microkernel.
|
||||||
|
thrinfo_t* caucus = bli_thrinfo_sub_node( thread );
|
||||||
|
|
||||||
|
// Query the number of threads and thread ids for each loop.
|
||||||
|
dim_t jr_nt = bli_thread_n_way( thread );
|
||||||
|
dim_t jr_tid = bli_thread_work_id( thread );
|
||||||
|
dim_t ir_nt = bli_thread_n_way( caucus );
|
||||||
|
dim_t ir_tid = bli_thread_work_id( caucus );
|
||||||
|
|
||||||
|
dim_t jr_start, jr_end;
|
||||||
|
dim_t ir_start, ir_end;
|
||||||
|
dim_t jr_inc, ir_inc;
|
||||||
|
|
||||||
|
// Determine the thread range and increment for the 2nd and 1st loops.
|
||||||
|
// NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||||
|
// slab or round-robin partitioning was requested at configure-time.
|
||||||
|
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
|
||||||
|
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );
|
||||||
|
|
||||||
|
// Loop over the n dimension (NR columns at a time).
|
||||||
|
for ( dim_t j = jr_start; j < jr_end; j += jr_inc )
|
||||||
|
{
|
||||||
|
char* b1 = b_cast + j * cstep_b;
|
||||||
|
char* c1 = c_cast + j * cstep_c;
|
||||||
|
|
||||||
|
dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );
|
||||||
|
|
||||||
|
// Initialize our next panel of B to be the current panel of B.
|
||||||
|
char* b2 = b1;
|
||||||
|
|
||||||
|
// Loop over the m dimension (MR rows at a time).
|
||||||
|
for ( dim_t i = ir_start; i < ir_end; i += ir_inc )
|
||||||
|
{
|
||||||
|
char* a1 = a_cast + i * rstep_a;
|
||||||
|
char* c11 = c1 + i * rstep_c;
|
||||||
|
|
||||||
|
dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );
|
||||||
|
|
||||||
|
// Compute the addresses of the next panels of A and B.
|
||||||
|
char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
|
||||||
|
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
|
||||||
|
{
|
||||||
|
a2 = a_cast;
|
||||||
|
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
|
||||||
|
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
|
||||||
|
b2 = b_cast;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save addresses of next panels of A and B to the auxinfo_t
|
||||||
|
// object.
|
||||||
|
bli_auxinfo_set_next_a( a2, &aux );
|
||||||
|
bli_auxinfo_set_next_b( b2, &aux );
|
||||||
|
|
||||||
|
// Edge case handling now occurs within the microkernel itself, but
|
||||||
|
// we must still explicitly accumulate to a temporary microtile in
|
||||||
|
// situations where a virtual microkernel is being used, such as
|
||||||
|
// during the 1m method or some cases of mixed datatypes.
|
||||||
|
if ( dt_exec == dt_c )
|
||||||
|
{
|
||||||
|
// Invoke the gemm micro-kernel.
|
||||||
|
gemm_ukr
|
||||||
|
(
|
||||||
|
m_cur,
|
||||||
|
n_cur,
|
||||||
|
k,
|
||||||
|
alpha_cast,
|
||||||
|
a1,
|
||||||
|
b1,
|
||||||
|
beta_cast,
|
||||||
|
c11, rs_c, cs_c,
|
||||||
|
&aux,
|
||||||
|
cntx
|
||||||
|
);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Invoke the gemm micro-kernel.
|
||||||
|
gemm_ukr
|
||||||
|
(
|
||||||
|
MR,
|
||||||
|
NR,
|
||||||
|
k,
|
||||||
|
alpha_cast,
|
||||||
|
a1,
|
||||||
|
b1,
|
||||||
|
zero,
|
||||||
|
&ct, rs_ct, cs_ct,
|
||||||
|
&aux,
|
||||||
|
cntx
|
||||||
|
);
|
||||||
|
|
||||||
|
// Accumulate to C with type-casting.
|
||||||
|
xbpys_mxn[ dt_exec ][ dt_c ]
|
||||||
|
(
|
||||||
|
m_cur, n_cur,
|
||||||
|
&ct, rs_ct, cs_ct,
|
||||||
|
beta_cast,
|
||||||
|
c11, rs_c, cs_c
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#undef GENTFUNC
|
|
||||||
#define GENTFUNC( ctype, ch, varname ) \
|
|
||||||
\
|
|
||||||
void PASTEMAC(ch,varname) \
|
|
||||||
( \
|
|
||||||
pack_t schema_a, \
|
|
||||||
pack_t schema_b, \
|
|
||||||
dim_t m, \
|
|
||||||
dim_t n, \
|
|
||||||
dim_t k, \
|
|
||||||
void* alpha, \
|
|
||||||
void* a, inc_t cs_a, inc_t is_a, \
|
|
||||||
dim_t pd_a, inc_t ps_a, \
|
|
||||||
void* b, inc_t rs_b, inc_t is_b, \
|
|
||||||
dim_t pd_b, inc_t ps_b, \
|
|
||||||
void* beta, \
|
|
||||||
void* c, inc_t rs_c, inc_t cs_c, \
|
|
||||||
cntx_t* cntx, \
|
|
||||||
rntm_t* rntm, \
|
|
||||||
thrinfo_t* thread \
|
|
||||||
) \
|
|
||||||
{ \
|
|
||||||
const num_t dt = PASTEMAC(ch,type); \
|
|
||||||
\
|
|
||||||
/* Alias some constants to simpler names. */ \
|
|
||||||
const dim_t MR = pd_a; \
|
|
||||||
const dim_t NR = pd_b; \
|
|
||||||
/*const dim_t PACKMR = cs_a;*/ \
|
|
||||||
/*const dim_t PACKNR = rs_b;*/ \
|
|
||||||
\
|
|
||||||
/* Query the context for the micro-kernel address and cast it to its
|
|
||||||
function pointer type. Note that the virtual gemm ukernel is queried
|
|
||||||
instead of the native gemm ukernel. This is needed for certain
|
|
||||||
situations for the 1m method that require an extra layer of logic
|
|
||||||
to allow for handling (for example) complex values of beta. Also
|
|
||||||
note that under certain circumstances, the real-domain version of
|
|
||||||
this macrokernel will be called for 1m (NOT the complex version)
|
|
||||||
as an optimization. In these cases, the corresponding real-domain
|
|
||||||
slots within the cntx_t's virtual gemm ukernel func_t will contain
|
|
||||||
pointers to the *native* gemm ukernel, thanks to logic in the
|
|
||||||
context initialization function for the induced method (defined
|
|
||||||
in bli_cntx_ref.c). */ \
|
|
||||||
PASTECH(ch,gemm_ukr_ft) \
|
|
||||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
\
|
|
||||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
|
||||||
temporary buffer are set so that they match the storage of the
|
|
||||||
original C matrix. For example, if C is column-stored, ct will be
|
|
||||||
column-stored as well. */ \
|
|
||||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
|
||||||
/ sizeof( ctype ) ] \
|
|
||||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
|
||||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
|
||||||
\
|
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict a_cast = a; \
|
|
||||||
ctype* restrict b_cast = b; \
|
|
||||||
ctype* restrict c_cast = c; \
|
|
||||||
ctype* restrict alpha_cast = alpha; \
|
|
||||||
ctype* restrict beta_cast = beta; \
|
|
||||||
ctype* restrict b1; \
|
|
||||||
ctype* restrict c1; \
|
|
||||||
\
|
|
||||||
dim_t m_iter, m_left; \
|
|
||||||
dim_t n_iter, n_left; \
|
|
||||||
dim_t i, j; \
|
|
||||||
dim_t m_cur; \
|
|
||||||
dim_t n_cur; \
|
|
||||||
inc_t rstep_a; \
|
|
||||||
inc_t cstep_b; \
|
|
||||||
inc_t rstep_c, cstep_c; \
|
|
||||||
auxinfo_t aux; \
|
|
||||||
\
|
|
||||||
/*
|
|
||||||
Assumptions/assertions:
|
|
||||||
rs_a == 1
|
|
||||||
cs_a == PACKMR
|
|
||||||
pd_a == MR
|
|
||||||
ps_a == stride to next micro-panel of A
|
|
||||||
rs_b == PACKNR
|
|
||||||
cs_b == 1
|
|
||||||
pd_b == NR
|
|
||||||
ps_b == stride to next micro-panel of B
|
|
||||||
rs_c == (no assumptions)
|
|
||||||
cs_c == (no assumptions)
|
|
||||||
*/ \
|
|
||||||
\
|
|
||||||
/* If any dimension is zero, return immediately. */ \
|
|
||||||
if ( bli_zero_dim3( m, n, k ) ) return; \
|
|
||||||
\
|
|
||||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
|
||||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
|
||||||
/* Compute number of primary and leftover components of the m and n
|
|
||||||
dimensions. */ \
|
|
||||||
n_iter = n / NR; \
|
|
||||||
n_left = n % NR; \
|
|
||||||
\
|
|
||||||
m_iter = m / MR; \
|
|
||||||
m_left = m % MR; \
|
|
||||||
\
|
|
||||||
if ( n_left ) ++n_iter; \
|
|
||||||
if ( m_left ) ++m_iter; \
|
|
||||||
\
|
|
||||||
/* Determine some increments used to step through A, B, and C. */ \
|
|
||||||
rstep_a = ps_a; \
|
|
||||||
\
|
|
||||||
cstep_b = ps_b; \
|
|
||||||
\
|
|
||||||
rstep_c = rs_c * MR; \
|
|
||||||
cstep_c = cs_c * NR; \
|
|
||||||
\
|
|
||||||
/* Save the pack schemas of A and B to the auxinfo_t object. */ \
|
|
||||||
bli_auxinfo_set_schema_a( schema_a, &aux ); \
|
|
||||||
bli_auxinfo_set_schema_b( schema_b, &aux ); \
|
|
||||||
\
|
|
||||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
|
||||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
|
||||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
|
||||||
\
|
|
||||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
|
||||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
|
||||||
1st (ir) loop around the microkernel. */ \
|
|
||||||
thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \
|
|
||||||
\
|
|
||||||
/* Query the number of threads and thread ids for each loop. */ \
|
|
||||||
dim_t jr_nt = bli_thread_n_way( thread ); \
|
|
||||||
dim_t jr_tid = bli_thread_work_id( thread ); \
|
|
||||||
dim_t ir_nt = bli_thread_n_way( caucus ); \
|
|
||||||
dim_t ir_tid = bli_thread_work_id( caucus ); \
|
|
||||||
\
|
|
||||||
dim_t jr_start, jr_end; \
|
|
||||||
dim_t ir_start, ir_end; \
|
|
||||||
dim_t jr_inc, ir_inc; \
|
|
||||||
\
|
|
||||||
/* Determine the thread range and increment for the 2nd and 1st loops.
|
|
||||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
|
||||||
slab or round-robin partitioning was requested at configure-time. */ \
|
|
||||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
|
||||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
|
||||||
\
|
|
||||||
/* Loop over the n dimension (NR columns at a time). */ \
|
|
||||||
for ( j = jr_start; j < jr_end; j += jr_inc ) \
|
|
||||||
{ \
|
|
||||||
ctype* restrict a1; \
|
|
||||||
ctype* restrict c11; \
|
|
||||||
ctype* restrict b2; \
|
|
||||||
\
|
|
||||||
b1 = b_cast + j * cstep_b; \
|
|
||||||
c1 = c_cast + j * cstep_c; \
|
|
||||||
\
|
|
||||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
|
||||||
\
|
|
||||||
/* Initialize our next panel of B to be the current panel of B. */ \
|
|
||||||
b2 = b1; \
|
|
||||||
\
|
|
||||||
/* Loop over the m dimension (MR rows at a time). */ \
|
|
||||||
for ( i = ir_start; i < ir_end; i += ir_inc ) \
|
|
||||||
{ \
|
|
||||||
ctype* restrict a2; \
|
|
||||||
\
|
|
||||||
a1 = a_cast + i * rstep_a; \
|
|
||||||
c11 = c1 + i * rstep_c; \
|
|
||||||
\
|
|
||||||
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
|
|
||||||
\
|
|
||||||
/* Compute the addresses of the next panels of A and B. */ \
|
|
||||||
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
|
|
||||||
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
|
|
||||||
{ \
|
|
||||||
a2 = a_cast; \
|
|
||||||
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
|
|
||||||
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
|
|
||||||
b2 = b_cast; \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
/* Save addresses of next panels of A and B to the auxinfo_t
|
|
||||||
object. */ \
|
|
||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
|
||||||
\
|
|
||||||
/* Handle interior and edge cases separately. */ \
|
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Scale the bottom edge of C and add the result from above. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
/*
|
/*
|
||||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
|
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" );
|
||||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \
|
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" );
|
||||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \
|
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" );
|
||||||
*/ \
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
INSERT_GENTFUNC_BASIC0( gemm_ker_var2 )
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,406 +0,0 @@
|
|||||||
/*
|
|
||||||
|
|
||||||
BLIS
|
|
||||||
An object-based framework for developing high-performance BLAS-like
|
|
||||||
libraries.
|
|
||||||
|
|
||||||
Copyright (C) 2014, The University of Texas at Austin
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are
|
|
||||||
met:
|
|
||||||
- Redistributions of source code must retain the above copyright
|
|
||||||
notice, this list of conditions and the following disclaimer.
|
|
||||||
- Redistributions in binary form must reproduce the above copyright
|
|
||||||
notice, this list of conditions and the following disclaimer in the
|
|
||||||
documentation and/or other materials provided with the distribution.
|
|
||||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived
|
|
||||||
from this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
||||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
||||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
||||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
||||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
||||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
||||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
||||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "blis.h"
|
|
||||||
|
|
||||||
#ifdef BLIS_ENABLE_GEMM_MD
|
|
||||||
|
|
||||||
#define FUNCPTR_T gemm_fp
|
|
||||||
|
|
||||||
typedef void (*FUNCPTR_T)
|
|
||||||
(
|
|
||||||
pack_t schema_a,
|
|
||||||
pack_t schema_b,
|
|
||||||
dim_t m,
|
|
||||||
dim_t n,
|
|
||||||
dim_t k,
|
|
||||||
void* alpha,
|
|
||||||
void* a, inc_t cs_a, inc_t is_a,
|
|
||||||
dim_t pd_a, inc_t ps_a,
|
|
||||||
void* b, inc_t rs_b, inc_t is_b,
|
|
||||||
dim_t pd_b, inc_t ps_b,
|
|
||||||
void* beta,
|
|
||||||
void* c, inc_t rs_c, inc_t cs_c,
|
|
||||||
cntx_t* cntx,
|
|
||||||
rntm_t* rntm,
|
|
||||||
thrinfo_t* thread
|
|
||||||
);
|
|
||||||
|
|
||||||
static FUNCPTR_T GENARRAY2_ALL(ftypes,gemm_ker_var2_md);
|
|
||||||
|
|
||||||
|
|
||||||
void bli_gemm_ker_var2_md
|
|
||||||
(
|
|
||||||
obj_t* a,
|
|
||||||
obj_t* b,
|
|
||||||
obj_t* c,
|
|
||||||
cntx_t* cntx,
|
|
||||||
rntm_t* rntm,
|
|
||||||
cntl_t* cntl,
|
|
||||||
thrinfo_t* thread
|
|
||||||
)
|
|
||||||
{
|
|
||||||
num_t dt_exec = bli_obj_exec_dt( c );
|
|
||||||
num_t dt_c = bli_obj_dt( c );
|
|
||||||
|
|
||||||
pack_t schema_a = bli_obj_pack_schema( a );
|
|
||||||
pack_t schema_b = bli_obj_pack_schema( b );
|
|
||||||
|
|
||||||
dim_t m = bli_obj_length( c );
|
|
||||||
dim_t n = bli_obj_width( c );
|
|
||||||
dim_t k = bli_obj_width( a );
|
|
||||||
|
|
||||||
void* buf_a = bli_obj_buffer_at_off( a );
|
|
||||||
inc_t cs_a = bli_obj_col_stride( a );
|
|
||||||
inc_t is_a = bli_obj_imag_stride( a );
|
|
||||||
dim_t pd_a = bli_obj_panel_dim( a );
|
|
||||||
inc_t ps_a = bli_obj_panel_stride( a );
|
|
||||||
|
|
||||||
void* buf_b = bli_obj_buffer_at_off( b );
|
|
||||||
inc_t rs_b = bli_obj_row_stride( b );
|
|
||||||
inc_t is_b = bli_obj_imag_stride( b );
|
|
||||||
dim_t pd_b = bli_obj_panel_dim( b );
|
|
||||||
inc_t ps_b = bli_obj_panel_stride( b );
|
|
||||||
|
|
||||||
void* buf_c = bli_obj_buffer_at_off( c );
|
|
||||||
inc_t rs_c = bli_obj_row_stride( c );
|
|
||||||
inc_t cs_c = bli_obj_col_stride( c );
|
|
||||||
|
|
||||||
obj_t scalar_a;
|
|
||||||
obj_t scalar_b;
|
|
||||||
|
|
||||||
void* buf_alpha;
|
|
||||||
void* buf_beta;
|
|
||||||
|
|
||||||
FUNCPTR_T f;
|
|
||||||
|
|
||||||
// Detach and multiply the scalars attached to A and B.
|
|
||||||
// NOTE: We know that the internal scalars of A and B are already of the
|
|
||||||
// target datatypes because the necessary typecasting would have already
|
|
||||||
// taken place during bli_packm_init().
|
|
||||||
bli_obj_scalar_detach( a, &scalar_a );
|
|
||||||
bli_obj_scalar_detach( b, &scalar_b );
|
|
||||||
bli_mulsc( &scalar_a, &scalar_b );
|
|
||||||
|
|
||||||
// Grab the addresses of the internal scalar buffers for the scalar
|
|
||||||
// merged above and the scalar attached to C.
|
|
||||||
// NOTE: We know that scalar_b is of type dt_exec due to the above code
|
|
||||||
// that casts the scalars of A and B to dt_exec via scalar_a and scalar_b,
|
|
||||||
// and we know that the internal scalar in C is already of the type dt_c
|
|
||||||
// due to the casting in the implementation of bli_obj_scalar_attach().
|
|
||||||
buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b );
|
|
||||||
buf_beta = bli_obj_internal_scalar_buffer( c );
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
// NOTE: Turns out that this optimization will never be employed since
|
|
||||||
// currently bli_gemm_ker_var2_md() is only called when the storage
|
|
||||||
// datatype of C differs from the execution/computation datatype, and
|
|
||||||
// this optimization would only make sense if they are equal.
|
|
||||||
|
|
||||||
// If 1m is being employed on a column- or row-stored matrix with a
|
|
||||||
// real-valued beta, we can use the real domain macro-kernel, which
|
|
||||||
// eliminates a little overhead associated with the 1m virtual
|
|
||||||
// micro-kernel.
|
|
||||||
if ( bli_cntx_method( cntx ) == BLIS_1M )
|
|
||||||
{
|
|
||||||
// Only employ this optimization if the storage datatype of C is
|
|
||||||
// equal to the execution/computation datatype.
|
|
||||||
if ( dt_c == dt_exec )
|
|
||||||
{
|
|
||||||
bli_gemm_ind_recast_1m_params
|
|
||||||
(
|
|
||||||
&dt_exec,
|
|
||||||
schema_a,
|
|
||||||
c,
|
|
||||||
&m, &n, &k,
|
|
||||||
&pd_a, &ps_a,
|
|
||||||
&pd_b, &ps_b,
|
|
||||||
&rs_c, &cs_c
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Tweak parameters in select mixed domain cases (rcc, crc, ccr).
|
|
||||||
bli_gemm_md_ker_var2_recast
|
|
||||||
(
|
|
||||||
&dt_exec,
|
|
||||||
bli_obj_dt( a ),
|
|
||||||
bli_obj_dt( b ),
|
|
||||||
bli_obj_dt( c ),
|
|
||||||
&m, &n, &k,
|
|
||||||
&pd_a, &ps_a,
|
|
||||||
&pd_b, &ps_b,
|
|
||||||
c,
|
|
||||||
&rs_c, &cs_c
|
|
||||||
);
|
|
||||||
|
|
||||||
// Index into the type combination array to extract the correct
|
|
||||||
// function pointer.
|
|
||||||
f = ftypes[dt_c][dt_exec];
|
|
||||||
|
|
||||||
// Invoke the function.
|
|
||||||
f( schema_a,
|
|
||||||
schema_b,
|
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
buf_alpha,
|
|
||||||
buf_a, cs_a, is_a,
|
|
||||||
pd_a, ps_a,
|
|
||||||
buf_b, rs_b, is_b,
|
|
||||||
pd_b, ps_b,
|
|
||||||
buf_beta,
|
|
||||||
buf_c, rs_c, cs_c,
|
|
||||||
cntx,
|
|
||||||
rntm,
|
|
||||||
thread );
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
#undef GENTFUNC2
|
|
||||||
#define GENTFUNC2( ctype_c, ctype_e, chc, che, varname ) \
|
|
||||||
\
|
|
||||||
void PASTEMAC2(chc,che,varname) \
|
|
||||||
( \
|
|
||||||
pack_t schema_a, \
|
|
||||||
pack_t schema_b, \
|
|
||||||
dim_t m, \
|
|
||||||
dim_t n, \
|
|
||||||
dim_t k, \
|
|
||||||
void* alpha, \
|
|
||||||
void* a, inc_t cs_a, inc_t is_a, \
|
|
||||||
dim_t pd_a, inc_t ps_a, \
|
|
||||||
void* b, inc_t rs_b, inc_t is_b, \
|
|
||||||
dim_t pd_b, inc_t ps_b, \
|
|
||||||
void* beta, \
|
|
||||||
void* c, inc_t rs_c, inc_t cs_c, \
|
|
||||||
cntx_t* cntx, \
|
|
||||||
rntm_t* rntm, \
|
|
||||||
thrinfo_t* thread \
|
|
||||||
) \
|
|
||||||
{ \
|
|
||||||
const num_t dte = PASTEMAC(che,type); \
|
|
||||||
/*const num_t dtc = PASTEMAC(chc,type);*/ \
|
|
||||||
\
|
|
||||||
/* Alias some constants to simpler names. */ \
|
|
||||||
const dim_t MR = pd_a; \
|
|
||||||
const dim_t NR = pd_b; \
|
|
||||||
/*const dim_t PACKMR = cs_a;*/ \
|
|
||||||
/*const dim_t PACKNR = rs_b;*/ \
|
|
||||||
\
|
|
||||||
/* Query the context for the micro-kernel address and cast it to its
|
|
||||||
function pointer type. */ \
|
|
||||||
PASTECH(che,gemm_ukr_ft) \
|
|
||||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dte, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
\
|
|
||||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
|
||||||
temporary buffer are set so that they match the storage of the
|
|
||||||
original C matrix. For example, if C is column-stored, ct will be
|
|
||||||
column-stored as well. */ \
|
|
||||||
ctype_e ct[ BLIS_STACK_BUF_MAX_SIZE \
|
|
||||||
/ sizeof( ctype_e ) ] \
|
|
||||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
|
||||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dte, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
|
||||||
\
|
|
||||||
ctype_e* restrict zero = PASTEMAC(che,0); \
|
|
||||||
ctype_e* restrict a_cast = a; \
|
|
||||||
ctype_e* restrict b_cast = b; \
|
|
||||||
ctype_c* restrict c_cast = c; \
|
|
||||||
ctype_e* restrict alpha_cast = alpha; \
|
|
||||||
ctype_c* restrict beta_cast = beta; \
|
|
||||||
ctype_e* restrict b1; \
|
|
||||||
ctype_c* restrict c1; \
|
|
||||||
\
|
|
||||||
dim_t m_iter, m_left; \
|
|
||||||
dim_t n_iter, n_left; \
|
|
||||||
dim_t i, j; \
|
|
||||||
dim_t m_cur; \
|
|
||||||
dim_t n_cur; \
|
|
||||||
inc_t rstep_a; \
|
|
||||||
inc_t cstep_b; \
|
|
||||||
inc_t rstep_c, cstep_c; \
|
|
||||||
auxinfo_t aux; \
|
|
||||||
\
|
|
||||||
/*
|
|
||||||
Assumptions/assertions:
|
|
||||||
rs_a == 1
|
|
||||||
cs_a == PACKMR
|
|
||||||
pd_a == MR
|
|
||||||
ps_a == stride to next micro-panel of A
|
|
||||||
rs_b == PACKNR
|
|
||||||
cs_b == 1
|
|
||||||
pd_b == NR
|
|
||||||
ps_b == stride to next micro-panel of B
|
|
||||||
rs_c == (no assumptions)
|
|
||||||
cs_c == (no assumptions)
|
|
||||||
*/ \
|
|
||||||
\
|
|
||||||
/* If any dimension is zero, return immediately. */ \
|
|
||||||
if ( bli_zero_dim3( m, n, k ) ) return; \
|
|
||||||
\
|
|
||||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
|
||||||
PASTEMAC(che,set0s_mxn)( MR, NR, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
|
||||||
/* Compute number of primary and leftover components of the m and n
|
|
||||||
dimensions. */ \
|
|
||||||
n_iter = n / NR; \
|
|
||||||
n_left = n % NR; \
|
|
||||||
\
|
|
||||||
m_iter = m / MR; \
|
|
||||||
m_left = m % MR; \
|
|
||||||
\
|
|
||||||
if ( n_left ) ++n_iter; \
|
|
||||||
if ( m_left ) ++m_iter; \
|
|
||||||
\
|
|
||||||
/* Determine some increments used to step through A, B, and C. */ \
|
|
||||||
rstep_a = ps_a; \
|
|
||||||
\
|
|
||||||
cstep_b = ps_b; \
|
|
||||||
\
|
|
||||||
rstep_c = rs_c * MR; \
|
|
||||||
cstep_c = cs_c * NR; \
|
|
||||||
\
|
|
||||||
/* Save the pack schemas of A and B to the auxinfo_t object. */ \
|
|
||||||
bli_auxinfo_set_schema_a( schema_a, &aux ); \
|
|
||||||
bli_auxinfo_set_schema_b( schema_b, &aux ); \
|
|
||||||
\
|
|
||||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
|
||||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
|
||||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
|
||||||
\
|
|
||||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
|
||||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
|
||||||
1st (ir) loop around the microkernel. */ \
|
|
||||||
thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \
|
|
||||||
\
|
|
||||||
/* Query the number of threads and thread ids for each loop. */ \
|
|
||||||
dim_t jr_nt = bli_thread_n_way( thread ); \
|
|
||||||
dim_t jr_tid = bli_thread_work_id( thread ); \
|
|
||||||
dim_t ir_nt = bli_thread_n_way( caucus ); \
|
|
||||||
dim_t ir_tid = bli_thread_work_id( caucus ); \
|
|
||||||
\
|
|
||||||
dim_t jr_start, jr_end; \
|
|
||||||
dim_t ir_start, ir_end; \
|
|
||||||
dim_t jr_inc, ir_inc; \
|
|
||||||
\
|
|
||||||
/* Determine the thread range and increment for the 2nd and 1st loops.
|
|
||||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
|
||||||
slab or round-robin partitioning was requested at configure-time. */ \
|
|
||||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
|
||||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
|
||||||
\
|
|
||||||
/* Loop over the n dimension (NR columns at a time). */ \
|
|
||||||
for ( j = jr_start; j < jr_end; j += jr_inc ) \
|
|
||||||
{ \
|
|
||||||
ctype_e* restrict a1; \
|
|
||||||
ctype_c* restrict c11; \
|
|
||||||
ctype_e* restrict b2; \
|
|
||||||
\
|
|
||||||
b1 = b_cast + j * cstep_b; \
|
|
||||||
c1 = c_cast + j * cstep_c; \
|
|
||||||
\
|
|
||||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
|
||||||
\
|
|
||||||
/* Initialize our next panel of B to be the current panel of B. */ \
|
|
||||||
b2 = b1; \
|
|
||||||
\
|
|
||||||
/* Loop over the m dimension (MR rows at a time). */ \
|
|
||||||
for ( i = ir_start; i < ir_end; i += ir_inc ) \
|
|
||||||
{ \
|
|
||||||
ctype_e* restrict a2; \
|
|
||||||
\
|
|
||||||
a1 = a_cast + i * rstep_a; \
|
|
||||||
c11 = c1 + i * rstep_c; \
|
|
||||||
\
|
|
||||||
m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \
|
|
||||||
\
|
|
||||||
/* Compute the addresses of the next panels of A and B. */ \
|
|
||||||
a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \
|
|
||||||
if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \
|
|
||||||
{ \
|
|
||||||
a2 = a_cast; \
|
|
||||||
b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \
|
|
||||||
if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \
|
|
||||||
b2 = b_cast; \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
/* Save addresses of next panels of A and B to the auxinfo_t
|
|
||||||
object. */ \
|
|
||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
|
||||||
\
|
|
||||||
/* Always save the micropanel product to the local microtile and
|
|
||||||
then accumulate it into C via the xpbys_mxn macro. */ \
|
|
||||||
/*if ( 1 )*/ \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Scale the microtile of C and add the result from above. */ \
|
|
||||||
PASTEMAC3(che,chc,chc,xpbys_mxn) \
|
|
||||||
( \
|
|
||||||
m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c \
|
|
||||||
); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
/*
|
|
||||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \
|
|
||||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \
|
|
||||||
PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \
|
|
||||||
*/ \
|
|
||||||
}
|
|
||||||
|
|
||||||
INSERT_GENTFUNC2_BASIC0( gemm_ker_var2_md )
|
|
||||||
INSERT_GENTFUNC2_MIXDP0( gemm_ker_var2_md )
|
|
||||||
|
|
||||||
#endif
|
|
||||||
@@ -154,7 +154,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
num_t* dt_comp,
|
num_t* dt_comp,
|
||||||
num_t dt_a,
|
num_t dt_a,
|
||||||
num_t dt_b,
|
num_t dt_b,
|
||||||
num_t dt_c,
|
num_t* dt_c,
|
||||||
dim_t* m,
|
dim_t* m,
|
||||||
dim_t* n,
|
dim_t* n,
|
||||||
dim_t* k,
|
dim_t* k,
|
||||||
@@ -164,7 +164,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
inc_t* rs_c, inc_t* cs_c
|
inc_t* rs_c, inc_t* cs_c
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
if ( bli_is_real( dt_c ) &&
|
if ( bli_is_real( *dt_c ) &&
|
||||||
bli_is_complex( dt_a ) &&
|
bli_is_complex( dt_a ) &&
|
||||||
bli_is_complex( dt_b ) )
|
bli_is_complex( dt_b ) )
|
||||||
{
|
{
|
||||||
@@ -177,7 +177,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
*ps_a *= 2;
|
*ps_a *= 2;
|
||||||
*ps_b *= 2;
|
*ps_b *= 2;
|
||||||
}
|
}
|
||||||
else if ( bli_is_complex( dt_c ) &&
|
else if ( bli_is_complex( *dt_c ) &&
|
||||||
bli_is_real( dt_a ) &&
|
bli_is_real( dt_a ) &&
|
||||||
bli_is_complex( dt_b ) )
|
bli_is_complex( dt_b ) )
|
||||||
{
|
{
|
||||||
@@ -197,6 +197,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
// to the real virtual microkernel slots of the context) instead of
|
// to the real virtual microkernel slots of the context) instead of
|
||||||
// the complex macrokernel and c2r virtual microkernel.
|
// the complex macrokernel and c2r virtual microkernel.
|
||||||
*dt_comp = bli_dt_proj_to_real( *dt_comp );
|
*dt_comp = bli_dt_proj_to_real( *dt_comp );
|
||||||
|
*dt_c = bli_dt_proj_to_real( *dt_c );
|
||||||
*n *= 2;
|
*n *= 2;
|
||||||
*pd_b *= 2; *ps_b *= 2;
|
*pd_b *= 2; *ps_b *= 2;
|
||||||
*rs_c *= 2;
|
*rs_c *= 2;
|
||||||
@@ -211,7 +212,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
*ps_a /= 2;
|
*ps_a /= 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if ( bli_is_complex( dt_c ) &&
|
else if ( bli_is_complex( *dt_c ) &&
|
||||||
bli_is_complex( dt_a ) &&
|
bli_is_complex( dt_a ) &&
|
||||||
bli_is_real( dt_b ) )
|
bli_is_real( dt_b ) )
|
||||||
{
|
{
|
||||||
@@ -231,6 +232,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
// to the real virtual microkernel slots of the context) instead of
|
// to the real virtual microkernel slots of the context) instead of
|
||||||
// the complex macrokernel and c2r virtual microkernel.
|
// the complex macrokernel and c2r virtual microkernel.
|
||||||
*dt_comp = bli_dt_proj_to_real( *dt_comp );
|
*dt_comp = bli_dt_proj_to_real( *dt_comp );
|
||||||
|
*dt_c = bli_dt_proj_to_real( *dt_c );
|
||||||
*m *= 2;
|
*m *= 2;
|
||||||
*pd_a *= 2; *ps_a *= 2;
|
*pd_a *= 2; *ps_a *= 2;
|
||||||
*cs_c *= 2;
|
*cs_c *= 2;
|
||||||
@@ -274,54 +276,3 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
//
|
|
||||||
// Prototype object-based interfaces.
|
|
||||||
//
|
|
||||||
|
|
||||||
#undef GENPROT
|
|
||||||
#define GENPROT( opname ) \
|
|
||||||
\
|
|
||||||
void PASTEMAC0(opname) \
|
|
||||||
( \
|
|
||||||
obj_t* a, \
|
|
||||||
obj_t* b, \
|
|
||||||
obj_t* c, \
|
|
||||||
cntx_t* cntx, \
|
|
||||||
rntm_t* rntm, \
|
|
||||||
cntl_t* cntl, \
|
|
||||||
thrinfo_t* thread \
|
|
||||||
);
|
|
||||||
|
|
||||||
GENPROT( gemm_ker_var2_md )
|
|
||||||
|
|
||||||
//
|
|
||||||
// Prototype BLAS-like interfaces with void pointer operands.
|
|
||||||
//
|
|
||||||
|
|
||||||
#undef GENTPROT2
|
|
||||||
#define GENTPROT2( ctype_c, ctype_e, chc, che, varname ) \
|
|
||||||
\
|
|
||||||
void PASTEMAC2(chc,che,varname) \
|
|
||||||
( \
|
|
||||||
pack_t schema_a, \
|
|
||||||
pack_t schema_b, \
|
|
||||||
dim_t m, \
|
|
||||||
dim_t n, \
|
|
||||||
dim_t k, \
|
|
||||||
void* alpha, \
|
|
||||||
void* a, inc_t cs_a, inc_t is_a, \
|
|
||||||
dim_t pd_a, inc_t ps_a, \
|
|
||||||
void* b, inc_t rs_b, inc_t is_b, \
|
|
||||||
dim_t pd_b, inc_t ps_b, \
|
|
||||||
void* beta, \
|
|
||||||
void* c, inc_t rs_c, inc_t cs_c, \
|
|
||||||
cntx_t* cntx, \
|
|
||||||
rntm_t* rntm, \
|
|
||||||
thrinfo_t* thread \
|
|
||||||
);
|
|
||||||
|
|
||||||
INSERT_GENTPROT2_BASIC0( gemm_ker_var2_md )
|
|
||||||
INSERT_GENTPROT2_MIXDP0( gemm_ker_var2_md )
|
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,8 @@
|
|||||||
\
|
\
|
||||||
void PASTEMAC2(ch,opname,suf) \
|
void PASTEMAC2(ch,opname,suf) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
@@ -61,6 +63,9 @@ void PASTEMAC2(ch,opname,suf) \
|
|||||||
\
|
\
|
||||||
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||||
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||||
|
\
|
||||||
|
dim_t mr_r = mr; \
|
||||||
|
dim_t nr_r = nr; \
|
||||||
\
|
\
|
||||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
||||||
/ sizeof( ctype_r ) ] \
|
/ sizeof( ctype_r ) ] \
|
||||||
@@ -81,6 +86,9 @@ void PASTEMAC2(ch,opname,suf) \
|
|||||||
\
|
\
|
||||||
ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \
|
ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \
|
||||||
ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \
|
ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \
|
||||||
|
\
|
||||||
|
dim_t m_use; \
|
||||||
|
dim_t n_use; \
|
||||||
\
|
\
|
||||||
ctype_r* c_use; \
|
ctype_r* c_use; \
|
||||||
inc_t rs_c_use; \
|
inc_t rs_c_use; \
|
||||||
@@ -146,17 +154,16 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
|||||||
rs_c_use = rs_ct; \
|
rs_c_use = rs_ct; \
|
||||||
cs_c_use = cs_ct; \
|
cs_c_use = cs_ct; \
|
||||||
\
|
\
|
||||||
/* Convert the strides from being in units of complex elements to
|
/* Convert the strides and corresponding microtile dimension from being
|
||||||
be in units of real elements. Note that we don't need to check for
|
in units of complex elements to be in units of real elements. */ \
|
||||||
general storage here because that case corresponds to the scenario
|
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \
|
||||||
where we are using the ct buffer and its rs_ct/cs_ct strides. */ \
|
else { rs_c_use *= 2; nr_r *= 2; }\
|
||||||
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \
|
|
||||||
else rs_c_use *= 2; \
|
|
||||||
\
|
|
||||||
\
|
\
|
||||||
/* c = beta * c + alpha_r * a * b; */ \
|
/* c = beta * c + alpha_r * a * b; */ \
|
||||||
rgemm_ukr \
|
rgemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
mr_r, \
|
||||||
|
nr_r, \
|
||||||
k, \
|
k, \
|
||||||
alpha_r, \
|
alpha_r, \
|
||||||
a_r, \
|
a_r, \
|
||||||
@@ -166,14 +173,12 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
|||||||
data, \
|
data, \
|
||||||
cntx \
|
cntx \
|
||||||
); \
|
); \
|
||||||
\
|
|
||||||
dim_t i, j; \
|
|
||||||
\
|
\
|
||||||
/* Accumulate the final result in ct back to c. */ \
|
/* Accumulate the final result in ct back to c. */ \
|
||||||
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
||||||
{ \
|
{ \
|
||||||
for ( j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
for ( i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
{ \
|
{ \
|
||||||
PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
|
PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
|
||||||
*(c + i*rs_c + j*cs_c ) ); \
|
*(c + i*rs_c + j*cs_c ) ); \
|
||||||
@@ -181,8 +186,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
|||||||
} \
|
} \
|
||||||
else if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
else if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||||
{ \
|
{ \
|
||||||
for ( j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
for ( i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
{ \
|
{ \
|
||||||
PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
|
PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||||
*(c + i*rs_c + j*cs_c ) ); \
|
*(c + i*rs_c + j*cs_c ) ); \
|
||||||
@@ -190,8 +195,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
|||||||
} \
|
} \
|
||||||
else \
|
else \
|
||||||
{ \
|
{ \
|
||||||
for ( j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
for ( i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
{ \
|
{ \
|
||||||
PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
|
PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||||
*beta, \
|
*beta, \
|
||||||
@@ -207,17 +212,19 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \
|
|||||||
c_use = ( ctype_r* )c; \
|
c_use = ( ctype_r* )c; \
|
||||||
rs_c_use = rs_c; \
|
rs_c_use = rs_c; \
|
||||||
cs_c_use = cs_c; \
|
cs_c_use = cs_c; \
|
||||||
|
m_use = m; \
|
||||||
|
n_use = n; \
|
||||||
\
|
\
|
||||||
/* Convert the strides from being in units of complex elements to
|
/* Convert the strides and corresponding microtile dimension from being
|
||||||
be in units of real elements. Note that we don't need to check for
|
in units of complex elements to be in units of real elements. */ \
|
||||||
general storage here because that case corresponds to the scenario
|
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \
|
||||||
where we are using the ct buffer and its rs_ct/cs_ct strides. */ \
|
else { rs_c_use *= 2; n_use *= 2; } \
|
||||||
if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \
|
|
||||||
else rs_c_use *= 2; \
|
|
||||||
\
|
\
|
||||||
/* c = beta * c + alpha_r * a * b; */ \
|
/* c = beta * c + alpha_r * a * b; */ \
|
||||||
rgemm_ukr \
|
rgemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
m_use, \
|
||||||
|
n_use, \
|
||||||
k, \
|
k, \
|
||||||
alpha_r, \
|
alpha_r, \
|
||||||
a_r, \
|
a_r, \
|
||||||
|
|||||||
@@ -34,6 +34,16 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// gemm kernel parameter struct.
|
||||||
|
//
|
||||||
|
|
||||||
|
typedef struct
|
||||||
|
{
|
||||||
|
gemm_ukr_vft ukr;
|
||||||
|
} gemm_ker_params_t;
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Prototype object-based interfaces.
|
// Prototype object-based interfaces.
|
||||||
//
|
//
|
||||||
@@ -59,32 +69,3 @@ GENPROT( gemm_blk_var3 )
|
|||||||
GENPROT( gemm_ker_var1 )
|
GENPROT( gemm_ker_var1 )
|
||||||
GENPROT( gemm_ker_var2 )
|
GENPROT( gemm_ker_var2 )
|
||||||
|
|
||||||
|
|
||||||
//
|
|
||||||
// Prototype BLAS-like interfaces with void pointer operands.
|
|
||||||
//
|
|
||||||
|
|
||||||
#undef GENTPROT
|
|
||||||
#define GENTPROT( ctype, ch, varname ) \
|
|
||||||
\
|
|
||||||
void PASTEMAC(ch,varname) \
|
|
||||||
( \
|
|
||||||
pack_t schema_a, \
|
|
||||||
pack_t schema_b, \
|
|
||||||
dim_t m, \
|
|
||||||
dim_t n, \
|
|
||||||
dim_t k, \
|
|
||||||
void* alpha, \
|
|
||||||
void* a, inc_t cs_a, inc_t is_a, \
|
|
||||||
dim_t pd_a, inc_t ps_a, \
|
|
||||||
void* b, inc_t rs_b, inc_t is_b, \
|
|
||||||
dim_t pd_b, inc_t ps_b, \
|
|
||||||
void* beta, \
|
|
||||||
void* c, inc_t rs_c, inc_t cs_c, \
|
|
||||||
cntx_t* cntx, \
|
|
||||||
rntm_t* rntm, \
|
|
||||||
thrinfo_t* thread \
|
|
||||||
);
|
|
||||||
|
|
||||||
INSERT_GENTPROT_BASIC0( gemm_ker_var2 )
|
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@
|
|||||||
BLIS_INLINE void bli_gemm_ind_recast_1m_params
|
BLIS_INLINE void bli_gemm_ind_recast_1m_params
|
||||||
(
|
(
|
||||||
num_t* dt_exec,
|
num_t* dt_exec,
|
||||||
|
num_t* dt_c,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
obj_t* c,
|
obj_t* c,
|
||||||
dim_t* m,
|
dim_t* m,
|
||||||
@@ -57,6 +58,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params
|
|||||||
!bli_is_gen_stored( *rs_c, *cs_c ) )
|
!bli_is_gen_stored( *rs_c, *cs_c ) )
|
||||||
{
|
{
|
||||||
*dt_exec = bli_dt_proj_to_real( *dt_exec );
|
*dt_exec = bli_dt_proj_to_real( *dt_exec );
|
||||||
|
*dt_c = bli_dt_proj_to_real( *dt_c );
|
||||||
|
|
||||||
if ( bli_is_1e_packed( schema_a ) )
|
if ( bli_is_1e_packed( schema_a ) )
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -279,6 +279,9 @@ void PASTEMAC(ch,varname) \
|
|||||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
||||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
bli_auxinfo_set_is_a( is_a, &aux ); \
|
||||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
bli_auxinfo_set_is_b( is_b, &aux ); \
|
||||||
|
\
|
||||||
|
/* Save the desired output datatype (indicating no typecasting). */ \
|
||||||
|
/*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \
|
||||||
\
|
\
|
||||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
loop around the microkernel. Here we query the thrinfo_t node for the
|
||||||
@@ -381,43 +384,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
And if we're strictly above the diagonal, we do nothing and
|
And if we're strictly above the diagonal, we do nothing and
|
||||||
continue. */ \
|
continue. */ \
|
||||||
{ \
|
{ \
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Scale the edge of C and add the result. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
@@ -490,6 +470,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
/* Invoke the gemm micro-kernel. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
gemm_ukr \
|
gemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
MR, \
|
||||||
|
NR, \
|
||||||
k, \
|
k, \
|
||||||
alpha_cast, \
|
alpha_cast, \
|
||||||
a1, \
|
a1, \
|
||||||
@@ -509,43 +491,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
} \
|
} \
|
||||||
else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
|
else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
|
||||||
{ \
|
{ \
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Scale the edge of C and add the result. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|||||||
@@ -281,6 +281,9 @@ void PASTEMAC(ch,varname) \
|
|||||||
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
/* Save the imaginary stride of A and B to the auxinfo_t object. */ \
|
||||||
bli_auxinfo_set_is_a( is_a, &aux ); \
|
bli_auxinfo_set_is_a( is_a, &aux ); \
|
||||||
bli_auxinfo_set_is_b( is_b, &aux ); \
|
bli_auxinfo_set_is_b( is_b, &aux ); \
|
||||||
|
\
|
||||||
|
/* Save the desired output datatype (indicating no typecasting). */ \
|
||||||
|
/*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \
|
||||||
\
|
\
|
||||||
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
|
||||||
loop around the microkernel. Here we query the thrinfo_t node for the
|
loop around the microkernel. Here we query the thrinfo_t node for the
|
||||||
@@ -385,6 +388,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
/* Invoke the gemm micro-kernel. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
gemm_ukr \
|
gemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
MR, \
|
||||||
|
NR, \
|
||||||
k, \
|
k, \
|
||||||
alpha_cast, \
|
alpha_cast, \
|
||||||
a1, \
|
a1, \
|
||||||
@@ -404,43 +409,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
} \
|
} \
|
||||||
else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
|
else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \
|
||||||
{ \
|
{ \
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Scale the edge of C and add the result. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
@@ -512,43 +494,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
And if we're strictly below the diagonal, we do nothing and
|
And if we're strictly below the diagonal, we do nothing and
|
||||||
continue. */ \
|
continue. */ \
|
||||||
{ \
|
{ \
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Scale the edge of C and add the result. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
beta_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|||||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
function pointer type. */ \
|
function pointer type. */ \
|
||||||
PASTECH(ch,gemm_ukr_ft) \
|
PASTECH(ch,gemm_ukr_ft) \
|
||||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||||
\
|
|
||||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
|
||||||
temporary buffer are set so that they match the storage of the
|
|
||||||
original C matrix. For example, if C is column-stored, ct will be
|
|
||||||
column-stored as well. */ \
|
|
||||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
|
||||||
/ sizeof( ctype ) ] \
|
|
||||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
|
||||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
|
||||||
\
|
\
|
||||||
ctype* restrict one = PASTEMAC(ch,1); \
|
ctype* restrict one = PASTEMAC(ch,1); \
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
ctype* restrict c_cast = c; \
|
ctype* restrict c_cast = c; \
|
||||||
@@ -254,10 +242,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
diagoffa = 0; \
|
diagoffa = 0; \
|
||||||
c_cast = c_cast + (i )*rs_c; \
|
c_cast = c_cast + (i )*rs_c; \
|
||||||
} \
|
} \
|
||||||
\
|
|
||||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
|
||||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
\
|
||||||
/* Compute number of primary and leftover components of the m and n
|
/* Compute number of primary and leftover components of the m and n
|
||||||
dimensions. */ \
|
dimensions. */ \
|
||||||
@@ -307,8 +291,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
dim_t jr_inc; \
|
dim_t jr_inc; \
|
||||||
\
|
\
|
||||||
/* Determine the thread range and increment for the 2nd loop.
|
/* Determine the thread range and increment for the 2nd loop.
|
||||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||||
slab or round-robin partitioning was requested at configure-time. \
|
slab or round-robin partitioning was requested at configure-time. \
|
||||||
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
||||||
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
||||||
/*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \
|
/*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \
|
||||||
@@ -379,47 +363,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k_a1011, \
|
||||||
k_a1011, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1_i, \
|
||||||
b1_i, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Copy edge elements of C to the temporary buffer. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
c11, rs_c, cs_c, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k_a1011, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1_i, \
|
|
||||||
beta_cast, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Copy the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
/*}*/ \
|
/*}*/ \
|
||||||
\
|
\
|
||||||
a1 += ps_a_cur; \
|
a1 += ps_a_cur; \
|
||||||
@@ -446,42 +403,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
one, \
|
||||||
one, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
/*}*/ \
|
/*}*/ \
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
|
|||||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
function pointer type. */ \
|
function pointer type. */ \
|
||||||
PASTECH(ch,gemm_ukr_ft) \
|
PASTECH(ch,gemm_ukr_ft) \
|
||||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||||
\
|
|
||||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
|
||||||
temporary buffer are set so that they match the storage of the
|
|
||||||
original C matrix. For example, if C is column-stored, ct will be
|
|
||||||
column-stored as well. */ \
|
|
||||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
|
||||||
/ sizeof( ctype ) ] \
|
|
||||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
|
||||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
|
||||||
\
|
\
|
||||||
ctype* restrict one = PASTEMAC(ch,1); \
|
ctype* restrict one = PASTEMAC(ch,1); \
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
ctype* restrict c_cast = c; \
|
ctype* restrict c_cast = c; \
|
||||||
@@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
{ \
|
{ \
|
||||||
m = -diagoffa + k; \
|
m = -diagoffa + k; \
|
||||||
} \
|
} \
|
||||||
\
|
|
||||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
|
||||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
\
|
||||||
/* Compute number of primary and leftover components of the m and n
|
/* Compute number of primary and leftover components of the m and n
|
||||||
dimensions. */ \
|
dimensions. */ \
|
||||||
@@ -386,47 +370,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k_a1112, \
|
||||||
k_a1112, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1_i, \
|
||||||
b1_i, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Copy edge elements of C to the temporary buffer. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
c11, rs_c, cs_c, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k_a1112, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1_i, \
|
|
||||||
beta_cast, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Copy the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
/*}*/ \
|
/*}*/ \
|
||||||
\
|
\
|
||||||
a1 += ps_a_cur; \
|
a1 += ps_a_cur; \
|
||||||
@@ -453,42 +410,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
one, \
|
||||||
one, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
/*}*/ \
|
/*}*/ \
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
|
|||||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
function pointer type. */ \
|
function pointer type. */ \
|
||||||
PASTECH(ch,gemm_ukr_ft) \
|
PASTECH(ch,gemm_ukr_ft) \
|
||||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||||
\
|
|
||||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
|
||||||
temporary buffer are set so that they match the storage of the
|
|
||||||
original C matrix. For example, if C is column-stored, ct will be
|
|
||||||
column-stored as well. */ \
|
|
||||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
|
||||||
/ sizeof( ctype ) ] \
|
|
||||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
|
||||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
|
||||||
\
|
\
|
||||||
ctype* restrict one = PASTEMAC(ch,1); \
|
ctype* restrict one = PASTEMAC(ch,1); \
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
ctype* restrict c_cast = c; \
|
ctype* restrict c_cast = c; \
|
||||||
@@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
{ \
|
{ \
|
||||||
n = diagoffb + k; \
|
n = diagoffb + k; \
|
||||||
} \
|
} \
|
||||||
\
|
|
||||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
|
||||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
\
|
||||||
/* Compute number of primary and leftover components of the m and n
|
/* Compute number of primary and leftover components of the m and n
|
||||||
dimensions. */ \
|
dimensions. */ \
|
||||||
@@ -335,9 +319,9 @@ void PASTEMAC(ch,varname) \
|
|||||||
\
|
\
|
||||||
/* Determine the thread range and increment for the 2nd and 1st loops for
|
/* Determine the thread range and increment for the 2nd and 1st loops for
|
||||||
the initial rectangular region of B (if it exists).
|
the initial rectangular region of B (if it exists).
|
||||||
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
NOTE: The definition of bli_thread_range_jrir() will depend on whether
|
||||||
slab or round-robin partitioning was requested at configure-time. \
|
slab or round-robin partitioning was requested at configure-time. \
|
||||||
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
NOTE: Parallelism in the 1st loop is disabled for now. */ \
|
||||||
bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \
|
||||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
||||||
\
|
\
|
||||||
@@ -382,42 +366,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
one, \
|
||||||
one, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
@@ -501,47 +463,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k_b1121, \
|
||||||
k_b1121, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1_i, \
|
||||||
a1_i, \
|
b1, \
|
||||||
b1, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Copy edge elements of C to the temporary buffer. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
c11, rs_c, cs_c, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k_b1121, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1_i, \
|
|
||||||
b1, \
|
|
||||||
beta_cast, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Copy the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
|
|||||||
@@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \
|
|||||||
function pointer type. */ \
|
function pointer type. */ \
|
||||||
PASTECH(ch,gemm_ukr_ft) \
|
PASTECH(ch,gemm_ukr_ft) \
|
||||||
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
||||||
\
|
|
||||||
/* Temporary C buffer for edge cases. Note that the strides of this
|
|
||||||
temporary buffer are set so that they match the storage of the
|
|
||||||
original C matrix. For example, if C is column-stored, ct will be
|
|
||||||
column-stored as well. */ \
|
|
||||||
ctype ct[ BLIS_STACK_BUF_MAX_SIZE \
|
|
||||||
/ sizeof( ctype ) ] \
|
|
||||||
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
|
||||||
const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \
|
|
||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
|
||||||
\
|
\
|
||||||
ctype* restrict one = PASTEMAC(ch,1); \
|
ctype* restrict one = PASTEMAC(ch,1); \
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
ctype* restrict c_cast = c; \
|
ctype* restrict c_cast = c; \
|
||||||
@@ -262,10 +250,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
{ \
|
{ \
|
||||||
k = -diagoffb + n; \
|
k = -diagoffb + n; \
|
||||||
} \
|
} \
|
||||||
\
|
|
||||||
/* Clear the temporary C buffer in case it has any infs or NaNs. */ \
|
|
||||||
PASTEMAC(ch,set0s_mxn)( MR, NR, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
\
|
||||||
/* Compute number of primary and leftover components of the m and n
|
/* Compute number of primary and leftover components of the m and n
|
||||||
dimensions. */ \
|
dimensions. */ \
|
||||||
@@ -410,47 +394,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k_b0111, \
|
||||||
k_b0111, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1_i, \
|
||||||
a1_i, \
|
b1, \
|
||||||
b1, \
|
beta_cast, \
|
||||||
beta_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Copy edge elements of C to the temporary buffer. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
c11, rs_c, cs_c, \
|
|
||||||
ct, rs_ct, cs_ct ); \
|
|
||||||
\
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k_b0111, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1_i, \
|
|
||||||
b1, \
|
|
||||||
beta_cast, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Copy the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
@@ -476,9 +433,9 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \
|
||||||
\
|
\
|
||||||
/* Advance the start and end iteration offsets for the rectangular region
|
/* Advance the start and end iteration offsets for the rectangular region
|
||||||
by the number of iterations used for the triangular region. */ \
|
by the number of iterations used for the triangular region. */ \
|
||||||
jr_start += n_iter_tri; \
|
jr_start += n_iter_tri; \
|
||||||
jr_end += n_iter_tri; \
|
jr_end += n_iter_tri; \
|
||||||
jb0 = n_iter_tri; \
|
jb0 = n_iter_tri; \
|
||||||
\
|
\
|
||||||
/* Save the resulting value of b1 from the previous loop since it represents
|
/* Save the resulting value of b1 from the previous loop since it represents
|
||||||
@@ -496,7 +453,7 @@ void PASTEMAC(ch,varname) \
|
|||||||
the starting address of the rectangular region (which is already
|
the starting address of the rectangular region (which is already
|
||||||
n_iter_tri logical iterations through B). */ \
|
n_iter_tri logical iterations through B). */ \
|
||||||
b1 = b_cast + (j-jb0) * cstep_b; \
|
b1 = b_cast + (j-jb0) * cstep_b; \
|
||||||
c1 = c_cast + j * cstep_c; \
|
c1 = c_cast + j * cstep_c; \
|
||||||
\
|
\
|
||||||
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \
|
||||||
\
|
\
|
||||||
@@ -533,42 +490,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
alpha_cast, \
|
||||||
alpha_cast, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
one, \
|
||||||
one, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
alpha_cast, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|||||||
@@ -40,27 +40,30 @@ cntl_t* bli_trsm_cntl_create
|
|||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
side_t side,
|
side_t side,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
if ( bli_is_left( side ) )
|
if ( bli_is_left( side ) )
|
||||||
return bli_trsm_l_cntl_create( rntm, schema_a, schema_b );
|
return bli_trsm_l_cntl_create( rntm, schema_a, schema_b, ker );
|
||||||
else
|
else
|
||||||
return bli_trsm_r_cntl_create( rntm, schema_a, schema_b );
|
return bli_trsm_r_cntl_create( rntm, schema_a, schema_b, ker );
|
||||||
}
|
}
|
||||||
|
|
||||||
cntl_t* bli_trsm_l_cntl_create
|
cntl_t* bli_trsm_l_cntl_create
|
||||||
(
|
(
|
||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
void_fp macro_kernel_p;
|
void_fp macro_kernel_p;
|
||||||
|
|
||||||
// Use the function pointer to the macrokernels that use slab
|
// Set the default macrokernel. If a non-NULL kernel function pointer is
|
||||||
// assignment of micropanels to threads in the jr and ir loops.
|
// passed in, we use that instead.
|
||||||
macro_kernel_p = bli_trsm_xx_ker_var2;
|
macro_kernel_p = bli_trsm_xx_ker_var2;
|
||||||
|
if ( ker ) macro_kernel_p = ker;
|
||||||
|
|
||||||
const opid_t family = BLIS_TRSM;
|
const opid_t family = BLIS_TRSM;
|
||||||
|
|
||||||
@@ -202,11 +205,15 @@ cntl_t* bli_trsm_r_cntl_create
|
|||||||
(
|
(
|
||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// NOTE: trsm macrokernels are presently disabled for right-side execution.
|
// NOTE: trsm macrokernels are presently disabled for right-side execution.
|
||||||
|
// Set the default macrokernel. If a non-NULL kernel function pointer is
|
||||||
|
// passed in, we use that instead.
|
||||||
void_fp macro_kernel_p = bli_trsm_xx_ker_var2;
|
void_fp macro_kernel_p = bli_trsm_xx_ker_var2;
|
||||||
|
if ( ker ) macro_kernel_p = ker;
|
||||||
|
|
||||||
const opid_t family = BLIS_TRSM;
|
const opid_t family = BLIS_TRSM;
|
||||||
|
|
||||||
|
|||||||
@@ -38,21 +38,24 @@ cntl_t* bli_trsm_cntl_create
|
|||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
side_t side,
|
side_t side,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
);
|
);
|
||||||
|
|
||||||
cntl_t* bli_trsm_l_cntl_create
|
cntl_t* bli_trsm_l_cntl_create
|
||||||
(
|
(
|
||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
);
|
);
|
||||||
|
|
||||||
cntl_t* bli_trsm_r_cntl_create
|
cntl_t* bli_trsm_r_cntl_create
|
||||||
(
|
(
|
||||||
rntm_t* rntm,
|
rntm_t* rntm,
|
||||||
pack_t schema_a,
|
pack_t schema_a,
|
||||||
pack_t schema_b
|
pack_t schema_b,
|
||||||
|
void_fp ker
|
||||||
);
|
);
|
||||||
|
|
||||||
void bli_trsm_cntl_free
|
void bli_trsm_cntl_free
|
||||||
|
|||||||
@@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||||
\
|
\
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
@@ -470,43 +469,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
minus_one, \
|
||||||
minus_one, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
alpha2_cast, \
|
||||||
alpha2_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
minus_one, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
alpha2_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
} \
|
} \
|
||||||
|
|||||||
@@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||||
\
|
\
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
@@ -480,43 +479,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( a2, &aux ); \
|
bli_auxinfo_set_next_a( a2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( b2, &aux ); \
|
bli_auxinfo_set_next_b( b2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
minus_one, \
|
||||||
minus_one, \
|
a1, \
|
||||||
a1, \
|
b1, \
|
||||||
b1, \
|
alpha2_cast, \
|
||||||
alpha2_cast, \
|
c11, rs_c, cs_c, \
|
||||||
c11, rs_c, cs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
minus_one, \
|
|
||||||
a1, \
|
|
||||||
b1, \
|
|
||||||
zero, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
alpha2_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
} \
|
} \
|
||||||
|
|||||||
@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||||
\
|
\
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
@@ -499,43 +498,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( b2, &aux ); \
|
bli_auxinfo_set_next_a( b2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( a2, &aux ); \
|
bli_auxinfo_set_next_b( a2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
minus_one, \
|
||||||
minus_one, \
|
b1, \
|
||||||
b1, \
|
a1, \
|
||||||
a1, \
|
alpha2_cast, \
|
||||||
alpha2_cast, \
|
c11, cs_c, rs_c, \
|
||||||
c11, cs_c, rs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
minus_one, \
|
|
||||||
b1, \
|
|
||||||
a1, \
|
|
||||||
zero, \
|
|
||||||
ct, cs_ct, rs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
alpha2_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
|
|||||||
@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
|
|||||||
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
const inc_t rs_ct = ( col_pref ? 1 : NR ); \
|
||||||
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
const inc_t cs_ct = ( col_pref ? MR : 1 ); \
|
||||||
\
|
\
|
||||||
ctype* restrict zero = PASTEMAC(ch,0); \
|
|
||||||
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
ctype* restrict minus_one = PASTEMAC(ch,m1); \
|
||||||
ctype* restrict a_cast = a; \
|
ctype* restrict a_cast = a; \
|
||||||
ctype* restrict b_cast = b; \
|
ctype* restrict b_cast = b; \
|
||||||
@@ -492,43 +491,20 @@ void PASTEMAC(ch,varname) \
|
|||||||
bli_auxinfo_set_next_a( b2, &aux ); \
|
bli_auxinfo_set_next_a( b2, &aux ); \
|
||||||
bli_auxinfo_set_next_b( a2, &aux ); \
|
bli_auxinfo_set_next_b( a2, &aux ); \
|
||||||
\
|
\
|
||||||
/* Handle interior and edge cases separately. */ \
|
/* Invoke the gemm micro-kernel. */ \
|
||||||
if ( m_cur == MR && n_cur == NR ) \
|
gemm_ukr \
|
||||||
{ \
|
( \
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
m_cur, \
|
||||||
gemm_ukr \
|
n_cur, \
|
||||||
( \
|
k, \
|
||||||
k, \
|
minus_one, \
|
||||||
minus_one, \
|
b1, \
|
||||||
b1, \
|
a1, \
|
||||||
a1, \
|
alpha2_cast, \
|
||||||
alpha2_cast, \
|
c11, cs_c, rs_c, \
|
||||||
c11, cs_c, rs_c, \
|
&aux, \
|
||||||
&aux, \
|
cntx \
|
||||||
cntx \
|
); \
|
||||||
); \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
/* Invoke the gemm micro-kernel. */ \
|
|
||||||
gemm_ukr \
|
|
||||||
( \
|
|
||||||
k, \
|
|
||||||
minus_one, \
|
|
||||||
b1, \
|
|
||||||
a1, \
|
|
||||||
zero, \
|
|
||||||
ct, cs_ct, rs_ct, \
|
|
||||||
&aux, \
|
|
||||||
cntx \
|
|
||||||
); \
|
|
||||||
\
|
|
||||||
/* Add the result to the edge of C. */ \
|
|
||||||
PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
|
|
||||||
ct, rs_ct, cs_ct, \
|
|
||||||
alpha2_cast, \
|
|
||||||
c11, rs_c, cs_c ); \
|
|
||||||
} \
|
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
a1 += rstep_a; \
|
a1 += rstep_a; \
|
||||||
|
|||||||
@@ -74,6 +74,15 @@ BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai )
|
|||||||
return ai->ps_b;
|
return ai->ps_b;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BLIS_INLINE void_fp bli_auxinfo_ukr( auxinfo_t* ai )
|
||||||
|
{
|
||||||
|
return ai->ukr;
|
||||||
|
}
|
||||||
|
BLIS_INLINE void* bli_auxinfo_params( auxinfo_t* ai )
|
||||||
|
{
|
||||||
|
return ai->params;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// auxinfo_t field modification
|
// auxinfo_t field modification
|
||||||
|
|
||||||
@@ -118,5 +127,14 @@ BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai )
|
|||||||
ai->ps_b = ps;
|
ai->ps_b = ps;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
BLIS_INLINE void bli_auxinfo_set_ukr( void_fp ukr, auxinfo_t* ai )
|
||||||
|
{
|
||||||
|
ai->ukr = ukr;
|
||||||
|
}
|
||||||
|
BLIS_INLINE void bli_auxinfo_set_params( void* params, auxinfo_t* ai )
|
||||||
|
{
|
||||||
|
ai->params = params;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|||||||
109
frame/include/bli_edge_case_macro_defs.h
Normal file
109
frame/include/bli_edge_case_macro_defs.h
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
/*
|
||||||
|
|
||||||
|
BLIS
|
||||||
|
An object-based framework for developing high-performance BLAS-like
|
||||||
|
libraries.
|
||||||
|
|
||||||
|
Copyright (C) 2021, The University of Texas at Austin
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
- Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
- Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived
|
||||||
|
from this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef BLIS_EDGE_CASE_MACRO_DEFS_H
|
||||||
|
#define BLIS_EDGE_CASE_MACRO_DEFS_H
|
||||||
|
|
||||||
|
|
||||||
|
// Helper macros for edge-case handling within gemm microkernels.
|
||||||
|
|
||||||
|
#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \
|
||||||
|
\
|
||||||
|
PASTEMAC(ch,ctype)* restrict _beta = beta; \
|
||||||
|
PASTEMAC(ch,ctype)* restrict _c = c; \
|
||||||
|
const inc_t _rs_c = rs_c; \
|
||||||
|
const inc_t _cs_c = cs_c; \
|
||||||
|
PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \
|
||||||
|
__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
|
||||||
|
const inc_t _rs_ct = row_major ? nr : 1; \
|
||||||
|
const inc_t _cs_ct = row_major ? 1 : mr;
|
||||||
|
|
||||||
|
#define GEMM_UKR_SETUP_CT_POST(ch) \
|
||||||
|
\
|
||||||
|
PASTEMAC(ch,ctype) _zero; \
|
||||||
|
PASTEMAC(ch,set0s)( _zero ); \
|
||||||
|
\
|
||||||
|
if ( _use_ct ) \
|
||||||
|
{ \
|
||||||
|
c = _ct; \
|
||||||
|
rs_c = _rs_ct; \
|
||||||
|
cs_c = _cs_ct; \
|
||||||
|
beta = &_zero; \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \
|
||||||
|
\
|
||||||
|
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||||
|
const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
|
||||||
|
m != mr || n != nr; \
|
||||||
|
GEMM_UKR_SETUP_CT_POST(ch);
|
||||||
|
|
||||||
|
#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \
|
||||||
|
\
|
||||||
|
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||||
|
const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \
|
||||||
|
m != mr || n != nr; \
|
||||||
|
GEMM_UKR_SETUP_CT_POST(ch);
|
||||||
|
|
||||||
|
#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \
|
||||||
|
\
|
||||||
|
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||||
|
const bool _use_ct = m != mr || n != nr; \
|
||||||
|
GEMM_UKR_SETUP_CT_POST(ch);
|
||||||
|
|
||||||
|
#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \
|
||||||
|
\
|
||||||
|
GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
|
||||||
|
const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
|
||||||
|
m != mr || n != nr || \
|
||||||
|
( (uintptr_t)_c % alignment ) || \
|
||||||
|
( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \
|
||||||
|
GEMM_UKR_SETUP_CT_POST(ch);
|
||||||
|
|
||||||
|
#define GEMM_UKR_FLUSH_CT(ch) \
|
||||||
|
\
|
||||||
|
if ( _use_ct ) \
|
||||||
|
{ \
|
||||||
|
PASTEMAC(ch,xpbys_mxn) \
|
||||||
|
( \
|
||||||
|
m, n, \
|
||||||
|
_ct, _rs_ct, _cs_ct, \
|
||||||
|
_beta, \
|
||||||
|
_c, _rs_c, _cs_c \
|
||||||
|
); \
|
||||||
|
} \
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
@@ -98,6 +98,7 @@
|
|||||||
#include "bli_gentprot_macro_defs.h"
|
#include "bli_gentprot_macro_defs.h"
|
||||||
|
|
||||||
#include "bli_misc_macro_defs.h"
|
#include "bli_misc_macro_defs.h"
|
||||||
|
#include "bli_edge_case_macro_defs.h"
|
||||||
#include "bli_param_macro_defs.h"
|
#include "bli_param_macro_defs.h"
|
||||||
#include "bli_obj_macro_defs.h"
|
#include "bli_obj_macro_defs.h"
|
||||||
#include "bli_complex_macro_defs.h"
|
#include "bli_complex_macro_defs.h"
|
||||||
|
|||||||
@@ -1144,6 +1144,13 @@ typedef struct
|
|||||||
inc_t ps_a;
|
inc_t ps_a;
|
||||||
inc_t ps_b;
|
inc_t ps_b;
|
||||||
|
|
||||||
|
// The type to convert to on output.
|
||||||
|
//num_t dt_on_output;
|
||||||
|
|
||||||
|
// (Virtual) microkernel address and additional parameters.
|
||||||
|
void_fp ukr;
|
||||||
|
void* params;
|
||||||
|
|
||||||
} auxinfo_t;
|
} auxinfo_t;
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,13 @@
|
|||||||
// 2vx10 microkernels.
|
// 2vx10 microkernels.
|
||||||
#include "armsve_asm_2vx10cmplx.h"
|
#include "armsve_asm_2vx10cmplx.h"
|
||||||
|
|
||||||
|
#include "arm_sve.h"
|
||||||
|
|
||||||
void bli_cgemm_armsve_asm_2vx10_unindexed
|
void bli_cgemm_armsve_asm_2vx10_unindexed
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
scomplex* restrict alpha,
|
scomplex* restrict alpha,
|
||||||
scomplex* restrict a,
|
scomplex* restrict a,
|
||||||
scomplex* restrict b,
|
scomplex* restrict b,
|
||||||
@@ -59,12 +63,15 @@ void bli_cgemm_armsve_asm_2vx10_unindexed
|
|||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_mker = k0 / 4;
|
uint64_t k_mker = k / 4;
|
||||||
uint64_t k_left = k0 % 4;
|
uint64_t k_left = k % 4;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
uint64_t info = 0;
|
uint64_t info = 0;
|
||||||
|
|
||||||
|
uint64_t mr = svcntw();
|
||||||
|
GEMM_UKR_SETUP_CT( c, mr, 10, false );
|
||||||
|
|
||||||
__asm__ volatile (
|
__asm__ volatile (
|
||||||
// " ldr x0, %[a] \n\t"
|
// " ldr x0, %[a] \n\t"
|
||||||
// " ldr x1, %[b] \n\t"
|
// " ldr x1, %[b] \n\t"
|
||||||
@@ -310,5 +317,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
|
|||||||
"z24","z25","z26","z27",
|
"z24","z25","z26","z27",
|
||||||
"z28","z29","z30","z31"
|
"z28","z29","z30","z31"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( c );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,13 @@
|
|||||||
// 2vx10 microkernels.
|
// 2vx10 microkernels.
|
||||||
#include "armsve_asm_2vx10.h"
|
#include "armsve_asm_2vx10.h"
|
||||||
|
|
||||||
|
#include "arm_sve.h"
|
||||||
|
|
||||||
void bli_dgemm_armsve_asm_2vx10_unindexed
|
void bli_dgemm_armsve_asm_2vx10_unindexed
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
double* restrict b,
|
double* restrict b,
|
||||||
@@ -59,11 +63,14 @@ void bli_dgemm_armsve_asm_2vx10_unindexed
|
|||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_mker = k0 / 4;
|
uint64_t k_mker = k / 4;
|
||||||
uint64_t k_left = k0 % 4;
|
uint64_t k_left = k % 4;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
|
|
||||||
|
uint64_t mr = 2*svcntd();
|
||||||
|
GEMM_UKR_SETUP_CT( d, mr, 10, false );
|
||||||
|
|
||||||
__asm__ volatile (
|
__asm__ volatile (
|
||||||
" ldr x0, %[a] \n\t"
|
" ldr x0, %[a] \n\t"
|
||||||
" ldr x1, %[b] \n\t"
|
" ldr x1, %[b] \n\t"
|
||||||
@@ -324,5 +331,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x
|
|||||||
"z24","z25","z26","z27",
|
"z24","z25","z26","z27",
|
||||||
"z28","z29","z30","z31"
|
"z28","z29","z30","z31"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,13 @@
|
|||||||
// 2vx10 microkernels.
|
// 2vx10 microkernels.
|
||||||
#include "armsve_asm_2vx10.h"
|
#include "armsve_asm_2vx10.h"
|
||||||
|
|
||||||
|
#include "arm_sve.h"
|
||||||
|
|
||||||
void bli_sgemm_armsve_asm_2vx10_unindexed
|
void bli_sgemm_armsve_asm_2vx10_unindexed
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
float* restrict b,
|
float* restrict b,
|
||||||
@@ -59,11 +63,14 @@ void bli_sgemm_armsve_asm_2vx10_unindexed
|
|||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_mker = k0 / 4;
|
uint64_t k_mker = k / 4;
|
||||||
uint64_t k_left = k0 % 4;
|
uint64_t k_left = k % 4;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
|
|
||||||
|
uint64_t mr = 2*svcntw();
|
||||||
|
GEMM_UKR_SETUP_CT( s, mr, 10, false );
|
||||||
|
|
||||||
__asm__ volatile (
|
__asm__ volatile (
|
||||||
" ldr x0, %[a] \n\t"
|
" ldr x0, %[a] \n\t"
|
||||||
" ldr x1, %[b] \n\t"
|
" ldr x1, %[b] \n\t"
|
||||||
@@ -310,5 +317,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x
|
|||||||
"z24","z25","z26","z27",
|
"z24","z25","z26","z27",
|
||||||
"z28","z29","z30","z31"
|
"z28","z29","z30","z31"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,13 @@
|
|||||||
// 2vx10 microkernels.
|
// 2vx10 microkernels.
|
||||||
#include "armsve_asm_2vx10cmplx.h"
|
#include "armsve_asm_2vx10cmplx.h"
|
||||||
|
|
||||||
|
#include "arm_sve.h"
|
||||||
|
|
||||||
void bli_zgemm_armsve_asm_2vx10_unindexed
|
void bli_zgemm_armsve_asm_2vx10_unindexed
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
dcomplex* restrict b,
|
dcomplex* restrict b,
|
||||||
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx10_unindexed
|
|||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_mker = k0 / 4;
|
uint64_t k_mker = k / 4;
|
||||||
uint64_t k_left = k0 % 4;
|
uint64_t k_left = k % 4;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
uint64_t info = 0;
|
uint64_t info = 0;
|
||||||
|
|
||||||
|
uint64_t mr = svcntd();
|
||||||
|
GEMM_UKR_SETUP_CT( z, mr, 10, false );
|
||||||
|
|
||||||
__asm__ volatile (
|
__asm__ volatile (
|
||||||
// " ldr x0, %[a] \n\t"
|
// " ldr x0, %[a] \n\t"
|
||||||
// " ldr x1, %[b] \n\t"
|
// " ldr x1, %[b] \n\t"
|
||||||
@@ -309,5 +316,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16)
|
|||||||
"z24","z25","z26","z27",
|
"z24","z25","z26","z27",
|
||||||
"z28","z29","z30","z31"
|
"z28","z29","z30","z31"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( z );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,13 @@
|
|||||||
// 2vx7 microkernels.
|
// 2vx7 microkernels.
|
||||||
#include "armsve_asm_2vx7cmplx.h"
|
#include "armsve_asm_2vx7cmplx.h"
|
||||||
|
|
||||||
|
#include "arm_sve.h"
|
||||||
|
|
||||||
void bli_zgemm_armsve_asm_2vx7_unindexed
|
void bli_zgemm_armsve_asm_2vx7_unindexed
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
dcomplex* restrict b,
|
dcomplex* restrict b,
|
||||||
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx7_unindexed
|
|||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_mker = k0 / 4;
|
uint64_t k_mker = k / 4;
|
||||||
uint64_t k_left = k0 % 4;
|
uint64_t k_left = k % 4;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
uint64_t info = 0;
|
uint64_t info = 0;
|
||||||
|
|
||||||
|
uint64_t mr = svcntd();
|
||||||
|
GEMM_UKR_SETUP_CT( z, mr, 7, false );
|
||||||
|
|
||||||
__asm__ volatile (
|
__asm__ volatile (
|
||||||
// " ldr x0, %[a] \n\t"
|
// " ldr x0, %[a] \n\t"
|
||||||
// " ldr x1, %[b] \n\t"
|
// " ldr x1, %[b] \n\t"
|
||||||
@@ -261,6 +268,8 @@ GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27
|
|||||||
"z24","z25","z26","z27",
|
"z24","z25","z26","z27",
|
||||||
"z28","z29","z30","z31"
|
"z28","z29","z30","z31"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( z );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,13 @@
|
|||||||
// 2vx8 microkernels.
|
// 2vx8 microkernels.
|
||||||
#include "armsve_asm_2vx8cmplx.h"
|
#include "armsve_asm_2vx8cmplx.h"
|
||||||
|
|
||||||
|
#include "arm_sve.h"
|
||||||
|
|
||||||
void bli_zgemm_armsve_asm_2vx8_unindexed
|
void bli_zgemm_armsve_asm_2vx8_unindexed
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
dcomplex* restrict b,
|
dcomplex* restrict b,
|
||||||
@@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx8_unindexed
|
|||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_mker = k0 / 6;
|
uint64_t k_mker = k / 6;
|
||||||
uint64_t k_left = k0 % 6;
|
uint64_t k_left = k % 6;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
uint64_t info = 0;
|
uint64_t info = 0;
|
||||||
|
|
||||||
|
uint64_t mr = svcntd();
|
||||||
|
GEMM_UKR_SETUP_CT( z, mr, 8, false );
|
||||||
|
|
||||||
__asm__ volatile (
|
__asm__ volatile (
|
||||||
// " ldr x0, %[a] \n\t"
|
// " ldr x0, %[a] \n\t"
|
||||||
// " ldr x1, %[b] \n\t"
|
// " ldr x1, %[b] \n\t"
|
||||||
@@ -286,5 +293,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16)
|
|||||||
"z24","z25","z26","z27",
|
"z24","z25","z26","z27",
|
||||||
"z28","z29","z30","z31"
|
"z28","z29","z30","z31"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( z );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,23 +48,23 @@ void bli_sgemm_armv7a_ker_4x4
|
|||||||
|
|
||||||
void bli_sgemm_armv7a_asm_4x4
|
void bli_sgemm_armv7a_asm_4x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
float* restrict b,
|
float* restrict b,
|
||||||
float* restrict beta,
|
float* restrict beta,
|
||||||
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
float* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint32_t k = k0;
|
GEMM_UKR_SETUP_CT_ANY( s, 4, 4, false );
|
||||||
uint32_t rs_c = rs_c0;
|
|
||||||
uint32_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||||
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -83,23 +83,23 @@ void bli_dgemm_armv7a_ker_4x4
|
|||||||
|
|
||||||
void bli_dgemm_armv7a_asm_4x4
|
void bli_dgemm_armv7a_asm_4x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
double* restrict b,
|
double* restrict b,
|
||||||
double* restrict beta,
|
double* restrict beta,
|
||||||
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
double* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint32_t k = k0;
|
GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false );
|
||||||
uint32_t rs_c = rs_c0;
|
|
||||||
uint32_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -118,23 +118,23 @@ void bli_cgemm_armv7a_ker_2x2
|
|||||||
|
|
||||||
void bli_cgemm_armv7a_asm_2x2
|
void bli_cgemm_armv7a_asm_2x2
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
scomplex* restrict alpha,
|
scomplex* restrict alpha,
|
||||||
scomplex* restrict a,
|
scomplex* restrict a,
|
||||||
scomplex* restrict b,
|
scomplex* restrict b,
|
||||||
scomplex* restrict beta,
|
scomplex* restrict beta,
|
||||||
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
scomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint32_t k = k0;
|
GEMM_UKR_SETUP_CT_ANY( c, 2, 2, false );
|
||||||
uint32_t rs_c = rs_c0;
|
|
||||||
uint32_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||||
|
GEMM_UKR_FLUSH_CT( c );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -153,22 +153,22 @@ void bli_zgemm_armv7a_ker_2x2
|
|||||||
|
|
||||||
void bli_zgemm_armv7a_asm_2x2
|
void bli_zgemm_armv7a_asm_2x2
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
dcomplex* restrict b,
|
dcomplex* restrict b,
|
||||||
dcomplex* restrict beta,
|
dcomplex* restrict beta,
|
||||||
dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
dcomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint32_t k = k0;
|
GEMM_UKR_SETUP_CT_ANY( z, 2, 2, false );
|
||||||
uint32_t rs_c = rs_c0;
|
|
||||||
uint32_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data );
|
||||||
|
GEMM_UKR_FLUSH_CT( z );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,9 @@
|
|||||||
|
|
||||||
void bli_sgemm_armv7a_int_4x4
|
void bli_sgemm_armv7a_int_4x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
float* restrict b,
|
float* restrict b,
|
||||||
@@ -49,12 +51,14 @@ void bli_sgemm_armv7a_int_4x4
|
|||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint32_t k_iter = k0 / 4;
|
uint32_t k_iter = k / 4;
|
||||||
uint32_t k_left = k0 % 4;
|
uint32_t k_left = k % 4;
|
||||||
uint32_t rs_c = rs_c0;
|
uint32_t rs_c = rs_c0;
|
||||||
uint32_t cs_c = cs_c0;
|
uint32_t cs_c = cs_c0;
|
||||||
uint32_t i;
|
uint32_t i;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( s, 4, 4, false );
|
||||||
|
|
||||||
void* a_next = bli_auxinfo_next_a( data );
|
void* a_next = bli_auxinfo_next_a( data );
|
||||||
void* b_next = bli_auxinfo_next_b( data );
|
void* b_next = bli_auxinfo_next_b( data );
|
||||||
|
|
||||||
@@ -82,47 +86,17 @@ void bli_sgemm_armv7a_int_4x4
|
|||||||
|
|
||||||
if ( *beta != 0.0F )
|
if ( *beta != 0.0F )
|
||||||
{
|
{
|
||||||
if ( rs_c == 1 )
|
// Load column 0
|
||||||
{
|
cv0 = vld1q_f32( c + 0*cs_c );
|
||||||
// Load column 0
|
|
||||||
cv0 = vld1q_f32( c + 0*rs_c + 0*cs_c );
|
|
||||||
|
|
||||||
// Load column 1
|
// Load column 1
|
||||||
cv1 = vld1q_f32( c + 0*rs_c + 1*cs_c );
|
cv1 = vld1q_f32( c + 1*cs_c );
|
||||||
|
|
||||||
// Load column 2
|
// Load column 2
|
||||||
cv2 = vld1q_f32( c + 0*rs_c + 2*cs_c );
|
cv2 = vld1q_f32( c + 2*cs_c );
|
||||||
|
|
||||||
// Load column 3
|
// Load column 3
|
||||||
cv3 = vld1q_f32( c + 0*rs_c + 3*cs_c );
|
cv3 = vld1q_f32( c + 3*cs_c );
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// Load column 0
|
|
||||||
cv0 = vld1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0);
|
|
||||||
cv0 = vld1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1);
|
|
||||||
cv0 = vld1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2);
|
|
||||||
cv0 = vld1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3);
|
|
||||||
|
|
||||||
// Load column 1
|
|
||||||
cv1 = vld1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0);
|
|
||||||
cv1 = vld1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1);
|
|
||||||
cv1 = vld1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2);
|
|
||||||
cv1 = vld1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3);
|
|
||||||
|
|
||||||
// Load column 2
|
|
||||||
cv2 = vld1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0);
|
|
||||||
cv2 = vld1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1);
|
|
||||||
cv2 = vld1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2);
|
|
||||||
cv2 = vld1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3);
|
|
||||||
|
|
||||||
// Load column 3
|
|
||||||
cv3 = vld1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0);
|
|
||||||
cv3 = vld1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1);
|
|
||||||
cv3 = vld1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2);
|
|
||||||
cv3 = vld1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3);
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -255,47 +229,22 @@ void bli_sgemm_armv7a_int_4x4
|
|||||||
cv3 = vmlaq_f32( cv3, abv3, alphav );
|
cv3 = vmlaq_f32( cv3, abv3, alphav );
|
||||||
}
|
}
|
||||||
|
|
||||||
if ( rs_c == 1 )
|
// Store column 0
|
||||||
{
|
vst1q_f32( c + 0*cs_c, cv0 );
|
||||||
// Store column 0
|
// Store column 1
|
||||||
vst1q_f32( c + 0*rs_c + 0*cs_c, cv0 );
|
vst1q_f32( c + 1*cs_c, cv1 );
|
||||||
// Store column 1
|
// Store column 2
|
||||||
vst1q_f32( c + 0*rs_c + 1*cs_c, cv1 );
|
vst1q_f32( c + 2*cs_c, cv2 );
|
||||||
// Store column 2
|
// Store column 3
|
||||||
vst1q_f32( c + 0*rs_c + 2*cs_c, cv2 );
|
vst1q_f32( c + 3*cs_c, cv3 );
|
||||||
// Store column 3
|
|
||||||
vst1q_f32( c + 0*rs_c + 3*cs_c, cv3 );
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// Store column 0
|
|
||||||
vst1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0);
|
|
||||||
vst1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1);
|
|
||||||
vst1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2);
|
|
||||||
vst1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3);
|
|
||||||
|
|
||||||
// Store column 1
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
vst1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0);
|
|
||||||
vst1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1);
|
|
||||||
vst1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2);
|
|
||||||
vst1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3);
|
|
||||||
|
|
||||||
// Store column 2
|
|
||||||
vst1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0);
|
|
||||||
vst1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1);
|
|
||||||
vst1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2);
|
|
||||||
vst1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3);
|
|
||||||
|
|
||||||
// Store column 3
|
|
||||||
vst1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0);
|
|
||||||
vst1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1);
|
|
||||||
vst1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2);
|
|
||||||
vst1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void bli_dgemm_armv7a_int_4x4
|
void bli_dgemm_armv7a_int_4x4
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
@@ -314,6 +263,8 @@ void bli_dgemm_armv7a_int_4x4
|
|||||||
uint32_t cs_c = cs_c0;
|
uint32_t cs_c = cs_c0;
|
||||||
uint32_t i;
|
uint32_t i;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false );
|
||||||
|
|
||||||
//void* a_next = bli_auxinfo_next_a( data );
|
//void* a_next = bli_auxinfo_next_a( data );
|
||||||
//void* b_next = bli_auxinfo_next_b( data );
|
//void* b_next = bli_auxinfo_next_b( data );
|
||||||
|
|
||||||
@@ -568,5 +519,7 @@ void bli_dgemm_armv7a_int_4x4
|
|||||||
*c23 += ab23 * *alpha;
|
*c23 += ab23 * *alpha;
|
||||||
*c33 += ab33 * *alpha;
|
*c33 += ab33 * *alpha;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -56,6 +56,8 @@
|
|||||||
|
|
||||||
void bli_dgemm_bgq_int_8x8
|
void bli_dgemm_bgq_int_8x8
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
@@ -66,6 +68,8 @@ void bli_dgemm_bgq_int_8x8
|
|||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
GEMM_UKR_SETUP_CT_ANY( d, 8, 8, false );
|
||||||
|
|
||||||
//Registers for storing C.
|
//Registers for storing C.
|
||||||
//4 4x4 subblocks of C, c00, c01, c10, c11
|
//4 4x4 subblocks of C, c00, c01, c10, c11
|
||||||
//4 registers per subblock: a, b, c, d
|
//4 registers per subblock: a, b, c, d
|
||||||
@@ -201,6 +205,8 @@ void bli_dgemm_bgq_int_8x8
|
|||||||
UPDATE( AB, c, 0 );
|
UPDATE( AB, c, 0 );
|
||||||
AB = vec_perm( c11d, c11d, pattern );
|
AB = vec_perm( c11d, c11d, pattern );
|
||||||
UPDATE( AB, c, 4 );
|
UPDATE( AB, c, 4 );
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|
||||||
void printvec(vector4double v)
|
void printvec(vector4double v)
|
||||||
@@ -214,6 +220,8 @@ void printvec(vector4double v)
|
|||||||
|
|
||||||
void bli_zgemm_bgq_int_4x4
|
void bli_zgemm_bgq_int_4x4
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
@@ -224,6 +232,8 @@ void bli_zgemm_bgq_int_4x4
|
|||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
GEMM_UKR_SETUP_CT_ANY( z, 4, 4, false );
|
||||||
|
|
||||||
double* a_d = ( double* )a;
|
double* a_d = ( double* )a;
|
||||||
double* b_d = ( double* )b;
|
double* b_d = ( double* )b;
|
||||||
double* c_d = ( double* )c;
|
double* c_d = ( double* )c;
|
||||||
@@ -368,4 +378,6 @@ void bli_zgemm_bgq_int_4x4
|
|||||||
c_d += 2*cs_c;
|
c_d += 2*cs_c;
|
||||||
ZUPDATE( c03a, c03b, c_d, 0 );
|
ZUPDATE( c03a, c03b, c_d, 0 );
|
||||||
ZUPDATE( c13a, c13b, c_d, 4 );
|
ZUPDATE( c13a, c13b, c_d, 4 );
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( z );
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -256,6 +256,8 @@ extern int offsets[16];
|
|||||||
//#define LOOPMON
|
//#define LOOPMON
|
||||||
void bli_dgemm_knc_asm_30x8
|
void bli_dgemm_knc_asm_30x8
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
@@ -273,80 +275,82 @@ void bli_dgemm_knc_asm_30x8
|
|||||||
|
|
||||||
uint64_t k64 = k;
|
uint64_t k64 = k;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( d, 30, 8, true );
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||||
#endif
|
#endif
|
||||||
#ifdef LOOPMON
|
#ifdef LOOPMON
|
||||||
int tlooph, tloopl, blooph, bloopl;
|
int tlooph, tloopl, blooph, bloopl;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
__asm
|
__asm
|
||||||
{
|
{
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
rdtsc
|
rdtsc
|
||||||
mov topl, eax
|
mov topl, eax
|
||||||
mov toph, edx
|
mov toph, edx
|
||||||
#endif
|
#endif
|
||||||
vpxord zmm0, zmm0, zmm0
|
vpxord zmm0, zmm0, zmm0
|
||||||
vmovaps zmm1, zmm0 //clear out registers
|
vmovaps zmm1, zmm0 //clear out registers
|
||||||
vmovaps zmm2, zmm0
|
vmovaps zmm2, zmm0
|
||||||
mov rsi, k64 //loop index
|
mov rsi, k64 //loop index
|
||||||
vmovaps zmm3, zmm0
|
vmovaps zmm3, zmm0
|
||||||
|
|
||||||
mov r11, rs_c //load row stride
|
mov r11, rs_c //load row stride
|
||||||
vmovaps zmm4, zmm0
|
vmovaps zmm4, zmm0
|
||||||
sal r11, 3 //scale row stride
|
sal r11, 3 //scale row stride
|
||||||
vmovaps zmm5, zmm0
|
vmovaps zmm5, zmm0
|
||||||
mov r15, a //load address of a
|
mov r15, a //load address of a
|
||||||
vmovaps zmm6, zmm0
|
vmovaps zmm6, zmm0
|
||||||
mov rbx, b //load address of b
|
mov rbx, b //load address of b
|
||||||
vmovaps zmm7, zmm0
|
vmovaps zmm7, zmm0
|
||||||
|
|
||||||
vmovaps zmm8, zmm0
|
vmovaps zmm8, zmm0
|
||||||
lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11
|
lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11
|
||||||
vmovaps zmm9, zmm0
|
vmovaps zmm9, zmm0
|
||||||
vmovaps zmm10, zmm0
|
vmovaps zmm10, zmm0
|
||||||
mov rdi, r11
|
mov rdi, r11
|
||||||
vmovaps zmm11, zmm0
|
vmovaps zmm11, zmm0
|
||||||
sal rdi, 2 //rdi has 4*r11
|
sal rdi, 2 //rdi has 4*r11
|
||||||
|
|
||||||
vmovaps zmm12, zmm0
|
vmovaps zmm12, zmm0
|
||||||
mov rcx, c //load address of c for prefetching
|
mov rcx, c //load address of c for prefetching
|
||||||
vmovaps zmm13, zmm0
|
vmovaps zmm13, zmm0
|
||||||
vmovaps zmm14, zmm0
|
vmovaps zmm14, zmm0
|
||||||
mov r8, k64
|
mov r8, k64
|
||||||
vmovaps zmm15, zmm0
|
vmovaps zmm15, zmm0
|
||||||
|
|
||||||
vmovaps zmm16, zmm0
|
vmovaps zmm16, zmm0
|
||||||
vmovaps zmm17, zmm0
|
vmovaps zmm17, zmm0
|
||||||
mov r13, L2_PREFETCH_DIST*8*8
|
mov r13, L2_PREFETCH_DIST*8*8
|
||||||
vmovaps zmm18, zmm0
|
vmovaps zmm18, zmm0
|
||||||
mov r14, L2_PREFETCH_DIST*8*32
|
mov r14, L2_PREFETCH_DIST*8*32
|
||||||
vmovaps zmm19, zmm0
|
vmovaps zmm19, zmm0
|
||||||
vmovaps zmm20, zmm0
|
vmovaps zmm20, zmm0
|
||||||
vmovaps zmm21, zmm0
|
vmovaps zmm21, zmm0
|
||||||
vmovaps zmm22, zmm0
|
vmovaps zmm22, zmm0
|
||||||
|
|
||||||
vmovaps zmm23, zmm0
|
vmovaps zmm23, zmm0
|
||||||
sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do.
|
sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do.
|
||||||
vmovaps zmm24, zmm0
|
vmovaps zmm24, zmm0
|
||||||
mov r8, 30
|
mov r8, 30
|
||||||
vmovaps zmm25, zmm0
|
vmovaps zmm25, zmm0
|
||||||
mov r9, 8*8 //amount to increment b* by each iteration
|
mov r9, 8*8 //amount to increment b* by each iteration
|
||||||
vmovaps zmm26, zmm0
|
vmovaps zmm26, zmm0
|
||||||
mov r12, 32*8 //amount to increment a* by each iteration
|
mov r12, 32*8 //amount to increment a* by each iteration
|
||||||
vmovaps zmm27, zmm0
|
vmovaps zmm27, zmm0
|
||||||
vmovaps zmm28, zmm0
|
vmovaps zmm28, zmm0
|
||||||
vmovaps zmm29, zmm0
|
vmovaps zmm29, zmm0
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
rdtsc
|
rdtsc
|
||||||
mov midl, eax
|
mov midl, eax
|
||||||
mov midh, edx
|
mov midh, edx
|
||||||
#endif
|
#endif
|
||||||
jle CONSIDER_UNDER_40
|
jle CONSIDER_UNDER_40
|
||||||
sub rsi, 30 + L2_PREFETCH_DIST
|
sub rsi, 30 + L2_PREFETCH_DIST
|
||||||
|
|
||||||
//First 30 iterations
|
//First 30 iterations
|
||||||
LOOPREFECHCL2:
|
LOOPREFECHCL2:
|
||||||
ONE_ITER_PC_L2(rcx)
|
ONE_ITER_PC_L2(rcx)
|
||||||
@@ -357,26 +361,26 @@ void bli_dgemm_knc_asm_30x8
|
|||||||
LOOPMAIN:
|
LOOPMAIN:
|
||||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||||
jne LOOPMAIN
|
jne LOOPMAIN
|
||||||
|
|
||||||
//Penultimate 22 iterations.
|
//Penultimate 22 iterations.
|
||||||
//Break these off from the main loop to avoid prefetching extra shit.
|
//Break these off from the main loop to avoid prefetching extra shit.
|
||||||
mov r14, a_next
|
mov r14, a_next
|
||||||
mov r13, b_next
|
mov r13, b_next
|
||||||
sub r14, r15
|
sub r14, r15
|
||||||
sub r13, rbx
|
sub r13, rbx
|
||||||
|
|
||||||
mov rsi, L2_PREFETCH_DIST-10
|
mov rsi, L2_PREFETCH_DIST-10
|
||||||
LOOPMAIN2:
|
LOOPMAIN2:
|
||||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||||
jne LOOPMAIN2
|
jne LOOPMAIN2
|
||||||
|
|
||||||
|
|
||||||
//Last 10 iterations
|
//Last 10 iterations
|
||||||
mov r8, 10
|
mov r8, 10
|
||||||
LOOPREFETCHCL1:
|
LOOPREFETCHCL1:
|
||||||
ONE_ITER_PC_L1(rcx)
|
ONE_ITER_PC_L1(rcx)
|
||||||
jne LOOPREFETCHCL1
|
jne LOOPREFETCHCL1
|
||||||
|
|
||||||
|
|
||||||
jmp POSTACCUM
|
jmp POSTACCUM
|
||||||
|
|
||||||
@@ -403,14 +407,8 @@ void bli_dgemm_knc_asm_30x8
|
|||||||
mov r9, c //load address of c for update
|
mov r9, c //load address of c for update
|
||||||
mov r12, alpha //load address of alpha
|
mov r12, alpha //load address of alpha
|
||||||
|
|
||||||
// Check if C is row stride. If not, jump to the slow scattered update
|
|
||||||
mov r14, cs_c
|
|
||||||
dec r14
|
|
||||||
jne SCATTEREDUPDATE
|
|
||||||
|
|
||||||
mov r14, beta
|
mov r14, beta
|
||||||
vbroadcastsd zmm31, 0[r14]
|
vbroadcastsd zmm31, 0[r14]
|
||||||
|
|
||||||
|
|
||||||
vmulpd zmm0, zmm0, 0[r12]{1to8}
|
vmulpd zmm0, zmm0, 0[r12]{1to8}
|
||||||
vmulpd zmm1, zmm1, 0[r12]{1to8}
|
vmulpd zmm1, zmm1, 0[r12]{1to8}
|
||||||
@@ -467,7 +465,7 @@ void bli_dgemm_knc_asm_30x8
|
|||||||
vmovapd [r9+2*r11+0], zmm14
|
vmovapd [r9+2*r11+0], zmm14
|
||||||
vmovapd [r9+r10+0], zmm15
|
vmovapd [r9+r10+0], zmm15
|
||||||
add r9, rdi
|
add r9, rdi
|
||||||
|
|
||||||
vmulpd zmm16, zmm16, 0[r12]{1to8}
|
vmulpd zmm16, zmm16, 0[r12]{1to8}
|
||||||
vmulpd zmm17, zmm17, 0[r12]{1to8}
|
vmulpd zmm17, zmm17, 0[r12]{1to8}
|
||||||
vmulpd zmm18, zmm18, 0[r12]{1to8}
|
vmulpd zmm18, zmm18, 0[r12]{1to8}
|
||||||
@@ -516,47 +514,6 @@ void bli_dgemm_knc_asm_30x8
|
|||||||
vfmadd231pd zmm29, zmm31, [r9+r11+0]
|
vfmadd231pd zmm29, zmm31, [r9+r11+0]
|
||||||
vmovapd [r9+0], zmm28
|
vmovapd [r9+0], zmm28
|
||||||
vmovapd [r9+r11+0], zmm29
|
vmovapd [r9+r11+0], zmm29
|
||||||
|
|
||||||
jmp END
|
|
||||||
|
|
||||||
SCATTEREDUPDATE:
|
|
||||||
mov r10, offsetPtr
|
|
||||||
vmovapd zmm31, 0[r10]
|
|
||||||
vpbroadcastd zmm30, cs_c
|
|
||||||
mov r13, beta
|
|
||||||
vpmulld zmm30, zmm31, zmm30
|
|
||||||
|
|
||||||
mov ebx, 255
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm0, 0, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm1, 1, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm2, 2, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm3, 3, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm4, 4, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm5, 5, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm6, 6, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm7, 7, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm8, 8, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm9, 9, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm10, 10, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm11, 11, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm12, 12, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm13, 13, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm14, 14, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm15, 15, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm16, 16, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm17, 17, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm18, 18, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm19, 19, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm20, 20, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm21, 21, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm22, 22, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm23, 23, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm24, 24, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm25, 25, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm26, 26, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm27, 27, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm28, 28, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm29, 29, r9)
|
|
||||||
|
|
||||||
END:
|
END:
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
@@ -566,6 +523,8 @@ void bli_dgemm_knc_asm_30x8
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
|
|
||||||
#ifdef LOOPMON
|
#ifdef LOOPMON
|
||||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -256,6 +256,8 @@ int offsets[16] __attribute__((aligned(0x1000))) = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
|||||||
//#define LOOPMON
|
//#define LOOPMON
|
||||||
void bli_sgemm_knc_asm_30x16
|
void bli_sgemm_knc_asm_30x16
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
@@ -273,80 +275,82 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
|
|
||||||
uint64_t k64 = k;
|
uint64_t k64 = k;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( s, 30, 16, true );
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||||
#endif
|
#endif
|
||||||
#ifdef LOOPMON
|
#ifdef LOOPMON
|
||||||
int tlooph, tloopl, blooph, bloopl;
|
int tlooph, tloopl, blooph, bloopl;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
__asm
|
__asm
|
||||||
{
|
{
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
rdtsc
|
rdtsc
|
||||||
mov topl, eax
|
mov topl, eax
|
||||||
mov toph, edx
|
mov toph, edx
|
||||||
#endif
|
#endif
|
||||||
vpxord zmm0, zmm0, zmm0
|
vpxord zmm0, zmm0, zmm0
|
||||||
vmovaps zmm1, zmm0 //clear out registers
|
vmovaps zmm1, zmm0 //clear out registers
|
||||||
vmovaps zmm2, zmm0
|
vmovaps zmm2, zmm0
|
||||||
mov rsi, k64 //loop index
|
mov rsi, k64 //loop index
|
||||||
vmovaps zmm3, zmm0
|
vmovaps zmm3, zmm0
|
||||||
|
|
||||||
mov r11, rs_c //load row stride
|
mov r11, rs_c //load row stride
|
||||||
vmovaps zmm4, zmm0
|
vmovaps zmm4, zmm0
|
||||||
sal r11, 2 //scale row stride
|
sal r11, 2 //scale row stride
|
||||||
vmovaps zmm5, zmm0
|
vmovaps zmm5, zmm0
|
||||||
mov r15, a //load address of a
|
mov r15, a //load address of a
|
||||||
vmovaps zmm6, zmm0
|
vmovaps zmm6, zmm0
|
||||||
mov rbx, b //load address of b
|
mov rbx, b //load address of b
|
||||||
vmovaps zmm7, zmm0
|
vmovaps zmm7, zmm0
|
||||||
|
|
||||||
vmovaps zmm8, zmm0
|
vmovaps zmm8, zmm0
|
||||||
lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11
|
lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11
|
||||||
vmovaps zmm9, zmm0
|
vmovaps zmm9, zmm0
|
||||||
vmovaps zmm10, zmm0
|
vmovaps zmm10, zmm0
|
||||||
mov rdi, r11
|
mov rdi, r11
|
||||||
vmovaps zmm11, zmm0
|
vmovaps zmm11, zmm0
|
||||||
sal rdi, 2 //rdi has 4*r11
|
sal rdi, 2 //rdi has 4*r11
|
||||||
|
|
||||||
vmovaps zmm12, zmm0
|
vmovaps zmm12, zmm0
|
||||||
mov rcx, c //load address of c for prefetching
|
mov rcx, c //load address of c for prefetching
|
||||||
vmovaps zmm13, zmm0
|
vmovaps zmm13, zmm0
|
||||||
vmovaps zmm14, zmm0
|
vmovaps zmm14, zmm0
|
||||||
mov r8, k64
|
mov r8, k64
|
||||||
vmovaps zmm15, zmm0
|
vmovaps zmm15, zmm0
|
||||||
|
|
||||||
vmovaps zmm16, zmm0
|
vmovaps zmm16, zmm0
|
||||||
vmovaps zmm17, zmm0
|
vmovaps zmm17, zmm0
|
||||||
mov r13, L2_PREFETCH_DIST*4*16
|
mov r13, L2_PREFETCH_DIST*4*16
|
||||||
vmovaps zmm18, zmm0
|
vmovaps zmm18, zmm0
|
||||||
mov r14, L2_PREFETCH_DIST*4*32
|
mov r14, L2_PREFETCH_DIST*4*32
|
||||||
vmovaps zmm19, zmm0
|
vmovaps zmm19, zmm0
|
||||||
vmovaps zmm20, zmm0
|
vmovaps zmm20, zmm0
|
||||||
vmovaps zmm21, zmm0
|
vmovaps zmm21, zmm0
|
||||||
vmovaps zmm22, zmm0
|
vmovaps zmm22, zmm0
|
||||||
|
|
||||||
vmovaps zmm23, zmm0
|
vmovaps zmm23, zmm0
|
||||||
sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do.
|
sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do.
|
||||||
vmovaps zmm24, zmm0
|
vmovaps zmm24, zmm0
|
||||||
mov r8, 30
|
mov r8, 30
|
||||||
vmovaps zmm25, zmm0
|
vmovaps zmm25, zmm0
|
||||||
mov r9, 16*4 //amount to increment b* by each iteration
|
mov r9, 16*4 //amount to increment b* by each iteration
|
||||||
vmovaps zmm26, zmm0
|
vmovaps zmm26, zmm0
|
||||||
mov r12, 32*4 //amount to increment a* by each iteration
|
mov r12, 32*4 //amount to increment a* by each iteration
|
||||||
vmovaps zmm27, zmm0
|
vmovaps zmm27, zmm0
|
||||||
vmovaps zmm28, zmm0
|
vmovaps zmm28, zmm0
|
||||||
vmovaps zmm29, zmm0
|
vmovaps zmm29, zmm0
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
rdtsc
|
rdtsc
|
||||||
mov midl, eax
|
mov midl, eax
|
||||||
mov midh, edx
|
mov midh, edx
|
||||||
#endif
|
#endif
|
||||||
jle CONSIDER_UNDER_40
|
jle CONSIDER_UNDER_40
|
||||||
sub rsi, 30 + L2_PREFETCH_DIST
|
sub rsi, 30 + L2_PREFETCH_DIST
|
||||||
|
|
||||||
//First 30 iterations
|
//First 30 iterations
|
||||||
LOOPREFECHCL2:
|
LOOPREFECHCL2:
|
||||||
ONE_ITER_PC_L2(rcx)
|
ONE_ITER_PC_L2(rcx)
|
||||||
@@ -357,26 +361,26 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
LOOPMAIN:
|
LOOPMAIN:
|
||||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||||
jne LOOPMAIN
|
jne LOOPMAIN
|
||||||
|
|
||||||
//Penultimate 22 iterations.
|
//Penultimate 22 iterations.
|
||||||
//Break these off from the main loop to avoid prefetching extra shit.
|
//Break these off from the main loop to avoid prefetching extra shit.
|
||||||
mov r14, a_next
|
mov r14, a_next
|
||||||
mov r13, b_next
|
mov r13, b_next
|
||||||
sub r14, r15
|
sub r14, r15
|
||||||
sub r13, rbx
|
sub r13, rbx
|
||||||
|
|
||||||
mov rsi, L2_PREFETCH_DIST-10
|
mov rsi, L2_PREFETCH_DIST-10
|
||||||
LOOPMAIN2:
|
LOOPMAIN2:
|
||||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||||
jne LOOPMAIN2
|
jne LOOPMAIN2
|
||||||
|
|
||||||
|
|
||||||
//Last 10 iterations
|
//Last 10 iterations
|
||||||
mov r8, 10
|
mov r8, 10
|
||||||
LOOPREFETCHCL1:
|
LOOPREFETCHCL1:
|
||||||
ONE_ITER_PC_L1(rcx)
|
ONE_ITER_PC_L1(rcx)
|
||||||
jne LOOPREFETCHCL1
|
jne LOOPREFETCHCL1
|
||||||
|
|
||||||
|
|
||||||
jmp POSTACCUM
|
jmp POSTACCUM
|
||||||
|
|
||||||
@@ -384,7 +388,7 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
//Used when <= 40 iterations
|
//Used when <= 40 iterations
|
||||||
CONSIDER_UNDER_40:
|
CONSIDER_UNDER_40:
|
||||||
mov rsi, k64
|
mov rsi, k64
|
||||||
test rsi, rsi
|
test rsi, rsi
|
||||||
je POSTACCUM
|
je POSTACCUM
|
||||||
LOOP_UNDER_40:
|
LOOP_UNDER_40:
|
||||||
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
ONE_ITER_MAIN_LOOP(rcx, rsi)
|
||||||
@@ -403,13 +407,8 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
mov r9, c //load address of c for update
|
mov r9, c //load address of c for update
|
||||||
mov r12, alpha //load address of alpha
|
mov r12, alpha //load address of alpha
|
||||||
|
|
||||||
// Check if C is row stride. If not, jump to the slow scattered update
|
|
||||||
mov r14, cs_c
|
|
||||||
dec r14
|
|
||||||
jne SCATTEREDUPDATE
|
|
||||||
|
|
||||||
mov r14, beta
|
mov r14, beta
|
||||||
vbroadcastss zmm31, 0[r14]
|
vbroadcastss zmm31, 0[r14]
|
||||||
|
|
||||||
|
|
||||||
vmulps zmm0, zmm0, 0[r12]{1to16}
|
vmulps zmm0, zmm0, 0[r12]{1to16}
|
||||||
@@ -467,7 +466,7 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
vmovaps [r9+2*r11+0], zmm14
|
vmovaps [r9+2*r11+0], zmm14
|
||||||
vmovaps [r9+r10+0], zmm15
|
vmovaps [r9+r10+0], zmm15
|
||||||
add r9, rdi
|
add r9, rdi
|
||||||
|
|
||||||
vmulps zmm16, zmm16, 0[r12]{1to16}
|
vmulps zmm16, zmm16, 0[r12]{1to16}
|
||||||
vmulps zmm17, zmm17, 0[r12]{1to16}
|
vmulps zmm17, zmm17, 0[r12]{1to16}
|
||||||
vmulps zmm18, zmm18, 0[r12]{1to16}
|
vmulps zmm18, zmm18, 0[r12]{1to16}
|
||||||
@@ -516,48 +515,6 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
vfmadd231ps zmm29, zmm31, [r9+r11+0]
|
vfmadd231ps zmm29, zmm31, [r9+r11+0]
|
||||||
vmovaps [r9+0], zmm28
|
vmovaps [r9+0], zmm28
|
||||||
vmovaps [r9+r11+0], zmm29
|
vmovaps [r9+r11+0], zmm29
|
||||||
|
|
||||||
jmp END
|
|
||||||
|
|
||||||
SCATTEREDUPDATE:
|
|
||||||
|
|
||||||
mov r10, offsetPtr
|
|
||||||
vmovaps zmm31, 0[r10]
|
|
||||||
vpbroadcastd zmm30, cs_c
|
|
||||||
mov r13, beta
|
|
||||||
vpmulld zmm30, zmm31, zmm30
|
|
||||||
|
|
||||||
mov ebx, 0xFFFF
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm0, 0, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm1, 1, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm2, 2, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm3, 3, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm4, 4, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm5, 5, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm6, 6, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm7, 7, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm8, 8, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm9, 9, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm10, 10, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm11, 11, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm12, 12, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm13, 13, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm14, 14, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm15, 15, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm16, 16, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm17, 17, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm18, 18, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm19, 19, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm20, 20, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm21, 21, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm22, 22, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm23, 23, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm24, 24, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm25, 25, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm26, 26, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm27, 27, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm28, 28, r9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(zmm29, 29, r9)
|
|
||||||
|
|
||||||
END:
|
END:
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
@@ -567,6 +524,8 @@ void bli_sgemm_knc_asm_30x16
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
|
|
||||||
#ifdef LOOPMON
|
#ifdef LOOPMON
|
||||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -185,6 +185,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) =
|
|||||||
//#define LOOPMON
|
//#define LOOPMON
|
||||||
void bli_dgemm_knl_asm_24x8
|
void bli_dgemm_knl_asm_24x8
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k_,
|
dim_t k_,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
@@ -201,10 +203,12 @@ void bli_dgemm_knl_asm_24x8
|
|||||||
const double * a_next = bli_auxinfo_next_a( data );
|
const double * a_next = bli_auxinfo_next_a( data );
|
||||||
const double * b_next = bli_auxinfo_next_b( data );
|
const double * b_next = bli_auxinfo_next_b( data );
|
||||||
|
|
||||||
const int32_t * offsetPtr = &offsets[0];
|
int32_t * offsetPtr = &offsets[0];
|
||||||
const int64_t k = k_;
|
int64_t k = k_;
|
||||||
const int64_t rs_c = rs_c_;
|
int64_t rs_c = rs_c_;
|
||||||
const int64_t cs_c = cs_c_;
|
int64_t cs_c = cs_c_;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( d, 24, 8, true );
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||||
@@ -565,10 +569,7 @@ void bli_dgemm_knl_asm_24x8
|
|||||||
// Check if C is row stride. If not, jump to the slow scattered update
|
// Check if C is row stride. If not, jump to the slow scattered update
|
||||||
MOV(RAX, VAR(rs_c))
|
MOV(RAX, VAR(rs_c))
|
||||||
LEA(RAX, MEM(,RAX,8))
|
LEA(RAX, MEM(,RAX,8))
|
||||||
MOV(RBX, VAR(cs_c))
|
|
||||||
LEA(RDI, MEM(RAX,RAX,2))
|
LEA(RDI, MEM(RAX,RAX,2))
|
||||||
CMP(RBX, IMM(1))
|
|
||||||
JNE(SCATTEREDUPDATE)
|
|
||||||
|
|
||||||
VMOVQ(RDX, XMM(1))
|
VMOVQ(RDX, XMM(1))
|
||||||
SAL(RDX) //shift out sign bit
|
SAL(RDX) //shift out sign bit
|
||||||
@@ -592,74 +593,6 @@ void bli_dgemm_knl_asm_24x8
|
|||||||
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
|
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
|
||||||
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
|
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
|
|
||||||
LABEL(SCATTEREDUPDATE)
|
|
||||||
|
|
||||||
MOV(RDI, VAR(offsetPtr))
|
|
||||||
VMOVAPS(ZMM(2), MEM(RDI))
|
|
||||||
/* Note that this ignores the upper 32 bits in cs_c */
|
|
||||||
VPBROADCASTD(ZMM(3), EBX)
|
|
||||||
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
|
|
||||||
|
|
||||||
VMOVQ(RDX, XMM(1))
|
|
||||||
SAL(RDX) //shift out sign bit
|
|
||||||
JZ(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_ROW_SCATTERED( 8)
|
|
||||||
UPDATE_C_ROW_SCATTERED( 9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(10)
|
|
||||||
UPDATE_C_ROW_SCATTERED(11)
|
|
||||||
UPDATE_C_ROW_SCATTERED(12)
|
|
||||||
UPDATE_C_ROW_SCATTERED(13)
|
|
||||||
UPDATE_C_ROW_SCATTERED(14)
|
|
||||||
UPDATE_C_ROW_SCATTERED(15)
|
|
||||||
UPDATE_C_ROW_SCATTERED(16)
|
|
||||||
UPDATE_C_ROW_SCATTERED(17)
|
|
||||||
UPDATE_C_ROW_SCATTERED(18)
|
|
||||||
UPDATE_C_ROW_SCATTERED(19)
|
|
||||||
UPDATE_C_ROW_SCATTERED(20)
|
|
||||||
UPDATE_C_ROW_SCATTERED(21)
|
|
||||||
UPDATE_C_ROW_SCATTERED(22)
|
|
||||||
UPDATE_C_ROW_SCATTERED(23)
|
|
||||||
UPDATE_C_ROW_SCATTERED(24)
|
|
||||||
UPDATE_C_ROW_SCATTERED(25)
|
|
||||||
UPDATE_C_ROW_SCATTERED(26)
|
|
||||||
UPDATE_C_ROW_SCATTERED(27)
|
|
||||||
UPDATE_C_ROW_SCATTERED(28)
|
|
||||||
UPDATE_C_ROW_SCATTERED(29)
|
|
||||||
UPDATE_C_ROW_SCATTERED(30)
|
|
||||||
UPDATE_C_ROW_SCATTERED(31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
|
|
||||||
LABEL(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED( 8)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED( 9)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(10)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(11)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(12)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(13)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(14)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(15)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(16)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(17)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(18)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(19)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(20)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(21)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(22)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(23)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(24)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(25)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(26)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(27)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(28)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(29)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(30)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(31)
|
|
||||||
|
|
||||||
LABEL(END)
|
LABEL(END)
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
@@ -701,6 +634,8 @@ void bli_dgemm_knl_asm_24x8
|
|||||||
"zmm30", "zmm31", "memory"
|
"zmm30", "zmm31", "memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
|
|
||||||
#ifdef LOOPMON
|
#ifdef LOOPMON
|
||||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -182,6 +182,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) =
|
|||||||
//#define LOOPMON
|
//#define LOOPMON
|
||||||
void bli_sgemm_knl_asm_24x16
|
void bli_sgemm_knl_asm_24x16
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k_,
|
dim_t k_,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
@@ -198,10 +200,12 @@ void bli_sgemm_knl_asm_24x16
|
|||||||
const double * a_next = bli_auxinfo_next_a( data );
|
const double * a_next = bli_auxinfo_next_a( data );
|
||||||
const double * b_next = bli_auxinfo_next_b( data );
|
const double * b_next = bli_auxinfo_next_b( data );
|
||||||
|
|
||||||
const int32_t * offsetPtr = &offsets[0];
|
int32_t * offsetPtr = &offsets[0];
|
||||||
const int64_t k = k_;
|
int64_t k = k_;
|
||||||
const int64_t rs_c = rs_c_;
|
int64_t rs_c = rs_c_;
|
||||||
const int64_t cs_c = cs_c_;
|
int64_t cs_c = cs_c_;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( s, 24, 16, true );
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
int toph, topl, both, botl, midl, midh, mid2l, mid2h;
|
||||||
@@ -562,10 +566,7 @@ void bli_sgemm_knl_asm_24x16
|
|||||||
// Check if C is row stride. If not, jump to the slow scattered update
|
// Check if C is row stride. If not, jump to the slow scattered update
|
||||||
MOV(RAX, VAR(rs_c))
|
MOV(RAX, VAR(rs_c))
|
||||||
LEA(RAX, MEM(,RAX,4))
|
LEA(RAX, MEM(,RAX,4))
|
||||||
MOV(RBX, VAR(cs_c))
|
|
||||||
LEA(RDI, MEM(RAX,RAX,2))
|
LEA(RDI, MEM(RAX,RAX,2))
|
||||||
CMP(RBX, IMM(1))
|
|
||||||
JNE(SCATTEREDUPDATE)
|
|
||||||
|
|
||||||
VMOVD(EDX, XMM(1))
|
VMOVD(EDX, XMM(1))
|
||||||
SAL(EDX) //shift out sign bit
|
SAL(EDX) //shift out sign bit
|
||||||
@@ -589,74 +590,6 @@ void bli_sgemm_knl_asm_24x16
|
|||||||
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
|
UPDATE_C_BZ_FOUR_ROWS(24,25,26,27)
|
||||||
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
|
UPDATE_C_BZ_FOUR_ROWS(28,29,30,31)
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
|
|
||||||
LABEL(SCATTEREDUPDATE)
|
|
||||||
|
|
||||||
MOV(RDI, VAR(offsetPtr))
|
|
||||||
VMOVAPS(ZMM(2), MEM(RDI))
|
|
||||||
/* Note that this ignores the upper 32 bits in cs_c */
|
|
||||||
VPBROADCASTD(ZMM(3), EBX)
|
|
||||||
VPMULLD(ZMM(2), ZMM(3), ZMM(2))
|
|
||||||
|
|
||||||
VMOVD(EDX, XMM(1))
|
|
||||||
SAL(EDX) //shift out sign bit
|
|
||||||
JZ(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_ROW_SCATTERED( 8)
|
|
||||||
UPDATE_C_ROW_SCATTERED( 9)
|
|
||||||
UPDATE_C_ROW_SCATTERED(10)
|
|
||||||
UPDATE_C_ROW_SCATTERED(11)
|
|
||||||
UPDATE_C_ROW_SCATTERED(12)
|
|
||||||
UPDATE_C_ROW_SCATTERED(13)
|
|
||||||
UPDATE_C_ROW_SCATTERED(14)
|
|
||||||
UPDATE_C_ROW_SCATTERED(15)
|
|
||||||
UPDATE_C_ROW_SCATTERED(16)
|
|
||||||
UPDATE_C_ROW_SCATTERED(17)
|
|
||||||
UPDATE_C_ROW_SCATTERED(18)
|
|
||||||
UPDATE_C_ROW_SCATTERED(19)
|
|
||||||
UPDATE_C_ROW_SCATTERED(20)
|
|
||||||
UPDATE_C_ROW_SCATTERED(21)
|
|
||||||
UPDATE_C_ROW_SCATTERED(22)
|
|
||||||
UPDATE_C_ROW_SCATTERED(23)
|
|
||||||
UPDATE_C_ROW_SCATTERED(24)
|
|
||||||
UPDATE_C_ROW_SCATTERED(25)
|
|
||||||
UPDATE_C_ROW_SCATTERED(26)
|
|
||||||
UPDATE_C_ROW_SCATTERED(27)
|
|
||||||
UPDATE_C_ROW_SCATTERED(28)
|
|
||||||
UPDATE_C_ROW_SCATTERED(29)
|
|
||||||
UPDATE_C_ROW_SCATTERED(30)
|
|
||||||
UPDATE_C_ROW_SCATTERED(31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
|
|
||||||
LABEL(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED( 8)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED( 9)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(10)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(11)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(12)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(13)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(14)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(15)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(16)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(17)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(18)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(19)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(20)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(21)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(22)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(23)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(24)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(25)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(26)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(27)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(28)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(29)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(30)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(31)
|
|
||||||
|
|
||||||
LABEL(END)
|
LABEL(END)
|
||||||
|
|
||||||
#ifdef MONITORS
|
#ifdef MONITORS
|
||||||
@@ -698,6 +631,8 @@ void bli_sgemm_knl_asm_24x16
|
|||||||
"zmm30", "zmm31", "memory"
|
"zmm30", "zmm31", "memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
|
|
||||||
#ifdef LOOPMON
|
#ifdef LOOPMON
|
||||||
printf("looptime = \t%d\n", bloopl - tloopl);
|
printf("looptime = \t%d\n", bloopl - tloopl);
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -37,7 +37,7 @@
|
|||||||
|
|
||||||
#define D_ASSEMBLE_VEC_PAIR \
|
#define D_ASSEMBLE_VEC_PAIR \
|
||||||
__builtin_mma_assemble_pair (&colA_1, ca[1], ca[0]); \
|
__builtin_mma_assemble_pair (&colA_1, ca[1], ca[0]); \
|
||||||
__builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]);
|
__builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]);
|
||||||
|
|
||||||
#define D_ACCUMULATE \
|
#define D_ACCUMULATE \
|
||||||
__builtin_mma_xvf64gerpp (&acc0, colA_1, rb[0]); \
|
__builtin_mma_xvf64gerpp (&acc0, colA_1, rb[0]); \
|
||||||
@@ -47,7 +47,7 @@
|
|||||||
__builtin_mma_xvf64gerpp (&acc4, colA_2, rb[0]); \
|
__builtin_mma_xvf64gerpp (&acc4, colA_2, rb[0]); \
|
||||||
__builtin_mma_xvf64gerpp (&acc5, colA_2, rb[1]); \
|
__builtin_mma_xvf64gerpp (&acc5, colA_2, rb[1]); \
|
||||||
__builtin_mma_xvf64gerpp (&acc6, colA_2, rb[2]); \
|
__builtin_mma_xvf64gerpp (&acc6, colA_2, rb[2]); \
|
||||||
__builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]);
|
__builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]);
|
||||||
|
|
||||||
#define D_INCREMENT \
|
#define D_INCREMENT \
|
||||||
A0+=8; \
|
A0+=8; \
|
||||||
@@ -57,17 +57,19 @@
|
|||||||
LOAD_VECTORS \
|
LOAD_VECTORS \
|
||||||
D_ASSEMBLE_VEC_PAIR \
|
D_ASSEMBLE_VEC_PAIR \
|
||||||
D_INCREMENT \
|
D_INCREMENT \
|
||||||
D_ACCUMULATE
|
D_ACCUMULATE
|
||||||
|
|
||||||
|
|
||||||
void bli_dgemm_power10_mma_8x8
|
void bli_dgemm_power10_mma_8x8
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
double* restrict b,
|
double* restrict b,
|
||||||
double* restrict beta,
|
double* restrict beta,
|
||||||
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
double* restrict c, inc_t rs_c0, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
@@ -76,11 +78,13 @@ void bli_dgemm_power10_mma_8x8
|
|||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out)
|
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out)
|
||||||
uint64_t k_iter = (k0-1) / 4;
|
uint64_t k_iter = (k-1) / 4;
|
||||||
uint64_t k_left = (k0-1) % 4;
|
uint64_t k_left = (k-1) % 4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( d, 8, 8, true );
|
||||||
|
|
||||||
double* restrict A0 = a;
|
double* restrict A0 = a;
|
||||||
double* restrict B0 = b;
|
double* restrict B0 = b;
|
||||||
double* restrict C0 = c;
|
double* restrict C0 = c;
|
||||||
@@ -92,23 +96,23 @@ void bli_dgemm_power10_mma_8x8
|
|||||||
dv4sf_t *rowC;
|
dv4sf_t *rowC;
|
||||||
|
|
||||||
/* 8 accumulator registers that will be used to store the result.
|
/* 8 accumulator registers that will be used to store the result.
|
||||||
|
|
||||||
Each accumulator register is mapped to 4 vector registers.
|
Each accumulator register is mapped to 4 vector registers.
|
||||||
Illustration:
|
Illustration:
|
||||||
|
|
||||||
acc0 = [ vs0
|
acc0 = [ vs0
|
||||||
vs1
|
vs1
|
||||||
vs3
|
vs3
|
||||||
vs4 ]
|
vs4 ]
|
||||||
|
|
||||||
These registers are used to store the result of an outer product
|
These registers are used to store the result of an outer product
|
||||||
instruction (general outer product instruction syntax: xv???ger??). */
|
instruction (general outer product instruction syntax: xv???ger??). */
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
/* 2 vector pairs are necessary for a double precision outer product
|
/* 2 vector pairs are necessary for a double precision outer product
|
||||||
instruction. */
|
instruction. */
|
||||||
__vector_pair colA_1,
|
__vector_pair colA_1,
|
||||||
colA_2;
|
colA_2;
|
||||||
|
|
||||||
/* Prefetch C so that it stays in cache */
|
/* Prefetch C so that it stays in cache */
|
||||||
@@ -123,17 +127,17 @@ void bli_dgemm_power10_mma_8x8
|
|||||||
|
|
||||||
/* Load elements into vector registers */
|
/* Load elements into vector registers */
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
vec_t *rb = (vec_t *) B0;
|
vec_t *rb = (vec_t *) B0;
|
||||||
|
|
||||||
/* Each accumulator represents a matrix of size
|
/* Each accumulator represents a matrix of size
|
||||||
4 x ( 16 / (datatype size in bytes) ) (vector register size = 16B)
|
4 x ( 16 / (datatype size in bytes) ) (vector register size = 16B)
|
||||||
|
|
||||||
Thus in the case of double, the accumulate registers represent a 4x2
|
Thus in the case of double, the accumulate registers represent a 4x2
|
||||||
matrix. However, a vector register can hold at most 2 doubles. Thus, if
|
matrix. However, a vector register can hold at most 2 doubles. Thus, if
|
||||||
we performed an outer product using 2 vector register, we can only get a
|
we performed an outer product using 2 vector register, we can only get a
|
||||||
2x2 matrix. Therefore, we must create a vector register pair in order
|
2x2 matrix. Therefore, we must create a vector register pair in order
|
||||||
to get the desired 4x2 matrix.
|
to get the desired 4x2 matrix.
|
||||||
|
|
||||||
*/
|
*/
|
||||||
D_ASSEMBLE_VEC_PAIR
|
D_ASSEMBLE_VEC_PAIR
|
||||||
|
|
||||||
@@ -158,7 +162,7 @@ void bli_dgemm_power10_mma_8x8
|
|||||||
D_AB_PRODUCT
|
D_AB_PRODUCT
|
||||||
D_AB_PRODUCT
|
D_AB_PRODUCT
|
||||||
}
|
}
|
||||||
|
|
||||||
// edge loop
|
// edge loop
|
||||||
for (int k = 0; k<k_left; k++)
|
for (int k = 0; k<k_left; k++)
|
||||||
{
|
{
|
||||||
@@ -189,4 +193,5 @@ void bli_dgemm_power10_mma_8x8
|
|||||||
SAVE_ACC_bz(dv4sf_t, &acc7, rs_c, 6+4*rs_c);
|
SAVE_ACC_bz(dv4sf_t, &acc7, rs_c, 6+4*rs_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,9 @@
|
|||||||
|
|
||||||
void bli_i16gemm_power10_mma_8x16
|
void bli_i16gemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
int32_t* restrict alpha,
|
int32_t* restrict alpha,
|
||||||
short* restrict a,
|
short* restrict a,
|
||||||
short* restrict b,
|
short* restrict b,
|
||||||
@@ -66,8 +68,8 @@ void bli_i16gemm_power10_mma_8x16
|
|||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
|
||||||
uint64_t k_iter = (k0-1) / 4;
|
uint64_t k_iter = (k-1) / 4;
|
||||||
uint64_t k_left = (k0-1) % 4;
|
uint64_t k_left = (k-1) % 4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
@@ -82,7 +84,7 @@ void bli_i16gemm_power10_mma_8x16
|
|||||||
iv4sf_t *rowC;
|
iv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
|
|||||||
@@ -55,7 +55,9 @@
|
|||||||
|
|
||||||
void bli_i16sgemm_power10_mma_8x16
|
void bli_i16sgemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
int32_t* restrict alpha,
|
int32_t* restrict alpha,
|
||||||
short* restrict a,
|
short* restrict a,
|
||||||
short* restrict b,
|
short* restrict b,
|
||||||
@@ -66,8 +68,8 @@ void bli_i16sgemm_power10_mma_8x16
|
|||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
|
||||||
uint64_t k_iter = (k0-1) / 4;
|
uint64_t k_iter = (k-1) / 4;
|
||||||
uint64_t k_left = (k0-1) % 4;
|
uint64_t k_left = (k-1) % 4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
@@ -82,7 +84,7 @@ void bli_i16sgemm_power10_mma_8x16
|
|||||||
iv4sf_t *rowC;
|
iv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
|
|||||||
@@ -55,7 +55,9 @@
|
|||||||
|
|
||||||
void bli_i4gemm_power10_mma_8x16
|
void bli_i4gemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
int32_t* restrict alpha,
|
int32_t* restrict alpha,
|
||||||
nibbles* restrict a,
|
nibbles* restrict a,
|
||||||
nibbles* restrict b,
|
nibbles* restrict b,
|
||||||
@@ -66,8 +68,8 @@ void bli_i4gemm_power10_mma_8x16
|
|||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
|
||||||
uint64_t k_iter = (k0-1) / 4;
|
uint64_t k_iter = (k-1) / 4;
|
||||||
uint64_t k_left = (k0-1) % 4;
|
uint64_t k_left = (k-1) % 4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
@@ -82,11 +84,11 @@ void bli_i4gemm_power10_mma_8x16
|
|||||||
iv4sf_t *rowC;
|
iv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
vec_t *rb = (vec_t *) B0;
|
vec_t *rb = (vec_t *) B0;
|
||||||
|
|
||||||
__builtin_mma_xvi4ger8 (&acc0, ca[0], rb[0]);
|
__builtin_mma_xvi4ger8 (&acc0, ca[0], rb[0]);
|
||||||
__builtin_mma_xvi4ger8 (&acc1, ca[0], rb[1]);
|
__builtin_mma_xvi4ger8 (&acc1, ca[0], rb[1]);
|
||||||
@@ -96,23 +98,23 @@ void bli_i4gemm_power10_mma_8x16
|
|||||||
__builtin_mma_xvi4ger8 (&acc5, ca[1], rb[1]);
|
__builtin_mma_xvi4ger8 (&acc5, ca[1], rb[1]);
|
||||||
__builtin_mma_xvi4ger8 (&acc6, ca[1], rb[2]);
|
__builtin_mma_xvi4ger8 (&acc6, ca[1], rb[2]);
|
||||||
__builtin_mma_xvi4ger8 (&acc7, ca[1], rb[3]);
|
__builtin_mma_xvi4ger8 (&acc7, ca[1], rb[3]);
|
||||||
|
|
||||||
I4_INCREMENT
|
I4_INCREMENT
|
||||||
|
|
||||||
// k loop (unrolled by 4)
|
// k loop (unrolled by 4)
|
||||||
for (int k = 0; k<k_iter; k++)
|
for (int k = 0; k<k_iter; k++)
|
||||||
{
|
{
|
||||||
I4_AB_PRODUCT
|
I4_AB_PRODUCT
|
||||||
I4_AB_PRODUCT
|
I4_AB_PRODUCT
|
||||||
I4_AB_PRODUCT
|
I4_AB_PRODUCT
|
||||||
I4_AB_PRODUCT
|
I4_AB_PRODUCT
|
||||||
}
|
}
|
||||||
|
|
||||||
// edge loop
|
// edge loop
|
||||||
for (int k = 0; k<k_left; k++)
|
for (int k = 0; k<k_left; k++)
|
||||||
{
|
{
|
||||||
I4_AB_PRODUCT
|
I4_AB_PRODUCT
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle beta cases
|
// handle beta cases
|
||||||
if (beta_ != 0.0)
|
if (beta_ != 0.0)
|
||||||
|
|||||||
@@ -55,7 +55,9 @@
|
|||||||
|
|
||||||
void bli_i8gemm_power10_mma_8x16
|
void bli_i8gemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
int32_t* restrict alpha,
|
int32_t* restrict alpha,
|
||||||
int8_t* restrict a,
|
int8_t* restrict a,
|
||||||
int8_t* restrict b,
|
int8_t* restrict b,
|
||||||
@@ -65,8 +67,8 @@ void bli_i8gemm_power10_mma_8x16
|
|||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
uint64_t k_iter = (k0-1) / 4;
|
uint64_t k_iter = (k-1) / 4;
|
||||||
uint64_t k_left = (k0-1) % 4;
|
uint64_t k_left = (k-1) % 4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
@@ -81,11 +83,11 @@ void bli_i8gemm_power10_mma_8x16
|
|||||||
iv4sf_t *rowC;
|
iv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
vec_t *rb = (vec_t *) B0;
|
vec_t *rb = (vec_t *) B0;
|
||||||
|
|
||||||
__builtin_mma_xvi8ger4 (&acc0, ca[0], rb[0]);
|
__builtin_mma_xvi8ger4 (&acc0, ca[0], rb[0]);
|
||||||
__builtin_mma_xvi8ger4 (&acc1, ca[0], rb[1]);
|
__builtin_mma_xvi8ger4 (&acc1, ca[0], rb[1]);
|
||||||
@@ -99,19 +101,19 @@ void bli_i8gemm_power10_mma_8x16
|
|||||||
I8_INCREMENT
|
I8_INCREMENT
|
||||||
|
|
||||||
// k loop (unrolled by 4)
|
// k loop (unrolled by 4)
|
||||||
for (int k = 0; k<k_iter; k++)
|
for (int k = 0; k<k_iter; k++)
|
||||||
{
|
{
|
||||||
I8_AB_PRODUCT
|
I8_AB_PRODUCT
|
||||||
I8_AB_PRODUCT
|
I8_AB_PRODUCT
|
||||||
I8_AB_PRODUCT
|
I8_AB_PRODUCT
|
||||||
I8_AB_PRODUCT
|
I8_AB_PRODUCT
|
||||||
}
|
}
|
||||||
|
|
||||||
// edge loop
|
// edge loop
|
||||||
for (int k = 0; k<k_left; k++)
|
for (int k = 0; k<k_left; k++)
|
||||||
{
|
{
|
||||||
I8_AB_PRODUCT
|
I8_AB_PRODUCT
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle beta cases
|
// handle beta cases
|
||||||
if (beta_ != 0.0)
|
if (beta_ != 0.0)
|
||||||
|
|||||||
@@ -42,21 +42,23 @@
|
|||||||
__builtin_mma_xvbf16ger2pp (&acc4, ca[1], rb[0]); \
|
__builtin_mma_xvbf16ger2pp (&acc4, ca[1], rb[0]); \
|
||||||
__builtin_mma_xvbf16ger2pp (&acc5, ca[1], rb[1]); \
|
__builtin_mma_xvbf16ger2pp (&acc5, ca[1], rb[1]); \
|
||||||
__builtin_mma_xvbf16ger2pp (&acc6, ca[1], rb[2]); \
|
__builtin_mma_xvbf16ger2pp (&acc6, ca[1], rb[2]); \
|
||||||
__builtin_mma_xvbf16ger2pp (&acc7, ca[1], rb[3]);
|
__builtin_mma_xvbf16ger2pp (&acc7, ca[1], rb[3]);
|
||||||
|
|
||||||
#define B_INCREMENT \
|
#define B_INCREMENT \
|
||||||
A0+=16; \
|
A0+=16; \
|
||||||
B0+=32;
|
B0+=32;
|
||||||
|
|
||||||
#define B_AB_PRODUCT \
|
#define B_AB_PRODUCT \
|
||||||
LOAD_VECTORS \
|
LOAD_VECTORS \
|
||||||
B_INCREMENT \
|
B_INCREMENT \
|
||||||
B_ACCUMULATE
|
B_ACCUMULATE
|
||||||
|
|
||||||
|
|
||||||
void bli_sbgemm_power10_mma_8x16
|
void bli_sbgemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
bfloat16* restrict a,
|
bfloat16* restrict a,
|
||||||
bfloat16* restrict b,
|
bfloat16* restrict b,
|
||||||
@@ -67,8 +69,8 @@ void bli_sbgemm_power10_mma_8x16
|
|||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
|
||||||
uint64_t k_iter = (k0-1)/4;
|
uint64_t k_iter = (k-1)/4;
|
||||||
uint64_t k_left = (k0-1)%4;
|
uint64_t k_left = (k-1)%4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
@@ -83,7 +85,7 @@ void bli_sbgemm_power10_mma_8x16
|
|||||||
fv4sf_t *rowC;
|
fv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
|
|||||||
@@ -42,7 +42,7 @@
|
|||||||
__builtin_mma_xvf32gerpp (&acc4, ca[1], rb[0]); \
|
__builtin_mma_xvf32gerpp (&acc4, ca[1], rb[0]); \
|
||||||
__builtin_mma_xvf32gerpp (&acc5, ca[1], rb[1]); \
|
__builtin_mma_xvf32gerpp (&acc5, ca[1], rb[1]); \
|
||||||
__builtin_mma_xvf32gerpp (&acc6, ca[1], rb[2]); \
|
__builtin_mma_xvf32gerpp (&acc6, ca[1], rb[2]); \
|
||||||
__builtin_mma_xvf32gerpp (&acc7, ca[1], rb[3]);
|
__builtin_mma_xvf32gerpp (&acc7, ca[1], rb[3]);
|
||||||
|
|
||||||
#define S_INCREMENT \
|
#define S_INCREMENT \
|
||||||
A0+=8; \
|
A0+=8; \
|
||||||
@@ -51,16 +51,18 @@
|
|||||||
#define S_AB_PRODUCT \
|
#define S_AB_PRODUCT \
|
||||||
LOAD_VECTORS \
|
LOAD_VECTORS \
|
||||||
S_INCREMENT \
|
S_INCREMENT \
|
||||||
S_ACCUMULATE
|
S_ACCUMULATE
|
||||||
|
|
||||||
void bli_sgemm_power10_mma_8x16
|
void bli_sgemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
float* restrict b,
|
float* restrict b,
|
||||||
float* restrict beta,
|
float* restrict beta,
|
||||||
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
float* restrict c, inc_t rs_c0, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
@@ -68,16 +70,18 @@ void bli_sgemm_power10_mma_8x16
|
|||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out)
|
// (1 is subtracted from k0 because 1 iteration of the k loop is pulled out)
|
||||||
uint64_t k_iter = (k0-1) / 4;
|
uint64_t k_iter = (k-1) / 4;
|
||||||
uint64_t k_left = (k0-1) % 4;
|
uint64_t k_left = (k-1) % 4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( s, 8, 16, true );
|
||||||
|
|
||||||
fv4sf_t result[4];
|
fv4sf_t result[4];
|
||||||
fv4sf_t *rowC;
|
fv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
float* restrict A0 = a;
|
float* restrict A0 = a;
|
||||||
@@ -111,7 +115,7 @@ void bli_sgemm_power10_mma_8x16
|
|||||||
S_AB_PRODUCT
|
S_AB_PRODUCT
|
||||||
S_AB_PRODUCT
|
S_AB_PRODUCT
|
||||||
}
|
}
|
||||||
|
|
||||||
// edge loop
|
// edge loop
|
||||||
for (int k = 0; k<k_left; k++)
|
for (int k = 0; k<k_left; k++)
|
||||||
{
|
{
|
||||||
@@ -141,4 +145,6 @@ void bli_sgemm_power10_mma_8x16
|
|||||||
SAVE_ACC_bz(fv4sf_t, &acc6, rs_c, 8+4*rs_c);
|
SAVE_ACC_bz(fv4sf_t, &acc6, rs_c, 8+4*rs_c);
|
||||||
SAVE_ACC_bz(fv4sf_t, &acc7, rs_c, 12+4*rs_c);
|
SAVE_ACC_bz(fv4sf_t, &acc7, rs_c, 12+4*rs_c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
}
|
}
|
||||||
@@ -42,21 +42,23 @@
|
|||||||
__builtin_mma_xvf16ger2pp (&acc4, ca[1], rb[0]); \
|
__builtin_mma_xvf16ger2pp (&acc4, ca[1], rb[0]); \
|
||||||
__builtin_mma_xvf16ger2pp (&acc5, ca[1], rb[1]); \
|
__builtin_mma_xvf16ger2pp (&acc5, ca[1], rb[1]); \
|
||||||
__builtin_mma_xvf16ger2pp (&acc6, ca[1], rb[2]); \
|
__builtin_mma_xvf16ger2pp (&acc6, ca[1], rb[2]); \
|
||||||
__builtin_mma_xvf16ger2pp (&acc7, ca[1], rb[3]);
|
__builtin_mma_xvf16ger2pp (&acc7, ca[1], rb[3]);
|
||||||
|
|
||||||
#define H_INCREMENT \
|
#define H_INCREMENT \
|
||||||
A0+=16; \
|
A0+=16; \
|
||||||
B0+=32;
|
B0+=32;
|
||||||
|
|
||||||
#define H_AB_PRODUCT \
|
#define H_AB_PRODUCT \
|
||||||
LOAD_VECTORS \
|
LOAD_VECTORS \
|
||||||
H_INCREMENT \
|
H_INCREMENT \
|
||||||
H_ACCUMULATE
|
H_ACCUMULATE
|
||||||
|
|
||||||
|
|
||||||
void bli_shgemm_power10_mma_8x16
|
void bli_shgemm_power10_mma_8x16
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float16* restrict a,
|
float16* restrict a,
|
||||||
float16* restrict b,
|
float16* restrict b,
|
||||||
@@ -67,8 +69,8 @@ void bli_shgemm_power10_mma_8x16
|
|||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
|
||||||
uint64_t k_iter = (k0-1)/4;
|
uint64_t k_iter = (k-1)/4;
|
||||||
uint64_t k_left = (k0-1)%4;
|
uint64_t k_left = (k-1)%4;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
|
|
||||||
@@ -83,7 +85,7 @@ void bli_shgemm_power10_mma_8x16
|
|||||||
fv4sf_t *rowC;
|
fv4sf_t *rowC;
|
||||||
|
|
||||||
// accumulators that will hold the matrix product
|
// accumulators that will hold the matrix product
|
||||||
__vector_quad acc0, acc1, acc2, acc3,
|
__vector_quad acc0, acc1, acc2, acc3,
|
||||||
acc4, acc5, acc6, acc7;
|
acc4, acc5, acc6, acc7;
|
||||||
|
|
||||||
vec_t *ca = (vec_t *) A0;
|
vec_t *ca = (vec_t *) A0;
|
||||||
|
|||||||
@@ -50,32 +50,28 @@
|
|||||||
*/
|
*/
|
||||||
void bli_sgemm_power7_int_8x4
|
void bli_sgemm_power7_int_8x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
float* restrict b,
|
float* restrict b,
|
||||||
float* restrict beta,
|
float* restrict beta,
|
||||||
float* restrict c, inc_t rs_c0, inc_t cs_c0,
|
float* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
|
||||||
// different size than is expected by load instructions.
|
|
||||||
uint64_t k = k0;
|
|
||||||
uint64_t rs_c = rs_c0;
|
|
||||||
uint64_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
#if 1 || defined(UTEST)
|
#if 1 || defined(UTEST)
|
||||||
const long MR = BLIS_DEFAULT_MR_S, NR = BLIS_DEFAULT_NR_S;
|
const long MR = BLIS_DEFAULT_MR_S, NR = BLIS_DEFAULT_NR_S;
|
||||||
const long LDA = MR, LDB = NR;
|
const long LDA = MR, LDB = NR;
|
||||||
long i, j, kk;
|
long i, j, kk;
|
||||||
float c00;
|
float c00;
|
||||||
|
|
||||||
for (i=0; i < MR; i++) {
|
for (i=0; i < m; i++) {
|
||||||
for (j=0; j < NR; j++) {
|
for (j=0; j < n; j++) {
|
||||||
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||||
for (kk=0; kk < k; kk++)
|
for (kk=0; kk < k; kk++)
|
||||||
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
|
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
|
||||||
c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00;
|
c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00;
|
||||||
}
|
}
|
||||||
@@ -96,24 +92,160 @@ void bli_sgemm_power7_int_8x4
|
|||||||
*/
|
*/
|
||||||
void bli_dgemm_power7_int_8x4
|
void bli_dgemm_power7_int_8x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
double* restrict b,
|
double* restrict b,
|
||||||
double* restrict beta,
|
double* restrict beta,
|
||||||
double* restrict c, inc_t rs_c0, inc_t cs_c0,
|
double* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
if ( cs_c == 1 )
|
||||||
// different size than is expected by load instructions.
|
{
|
||||||
uint64_t k = k0;
|
// Optimized code for case where C rows are contiguous (i.e. C is row-major)
|
||||||
uint64_t rs_c = rs_c0;
|
|
||||||
uint64_t cs_c = cs_c0;
|
vector double vzero = vec_splats( 0.0 );
|
||||||
|
|
||||||
|
vector double vc00_01 = vzero;
|
||||||
|
vector double vc02_03 = vzero;
|
||||||
|
vector double vc10_11 = vzero;
|
||||||
|
vector double vc12_13 = vzero;
|
||||||
|
vector double vc20_21 = vzero;
|
||||||
|
vector double vc22_23 = vzero;
|
||||||
|
vector double vc30_31 = vzero;
|
||||||
|
vector double vc32_33 = vzero;
|
||||||
|
vector double vc40_41 = vzero;
|
||||||
|
vector double vc42_43 = vzero;
|
||||||
|
vector double vc50_51 = vzero;
|
||||||
|
vector double vc52_53 = vzero;
|
||||||
|
vector double vc60_61 = vzero;
|
||||||
|
vector double vc62_63 = vzero;
|
||||||
|
vector double vc70_71 = vzero;
|
||||||
|
vector double vc72_73 = vzero;
|
||||||
|
|
||||||
|
unsigned long long pa = (unsigned long long)a;
|
||||||
|
unsigned long long pb = (unsigned long long)b;
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
unsigned long long d1 = 1*sizeof(double);
|
||||||
|
unsigned long long d2 = 2*sizeof(double);
|
||||||
|
unsigned long long d3 = 3*sizeof(double);
|
||||||
|
unsigned long long d4 = 4*sizeof(double);
|
||||||
|
unsigned long long d6 = 6*sizeof(double);
|
||||||
|
#else
|
||||||
|
// ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables
|
||||||
|
register unsigned long long d1 __asm ("r21") = 1*sizeof(double);
|
||||||
|
register unsigned long long d2 __asm ("r22") = 2*sizeof(double);
|
||||||
|
register unsigned long long d3 __asm ("r23") = 3*sizeof(double);
|
||||||
|
register unsigned long long d4 __asm ("r24") = 4*sizeof(double);
|
||||||
|
register unsigned long long d5 __asm ("r25") = 5*sizeof(double);
|
||||||
|
register unsigned long long d6 __asm ("r26") = 6*sizeof(double);
|
||||||
|
register unsigned long long d7 __asm ("r27") = 7*sizeof(double);
|
||||||
|
|
||||||
|
__asm__ volatile (";" : "=r" (d1) : "r" (d1) );
|
||||||
|
__asm__ volatile (";" : "=r" (d2) : "r" (d2) );
|
||||||
|
__asm__ volatile (";" : "=r" (d3) : "r" (d3) );
|
||||||
|
__asm__ volatile (";" : "=r" (d4) : "r" (d4) );
|
||||||
|
__asm__ volatile (";" : "=r" (d5) : "r" (d5) );
|
||||||
|
__asm__ volatile (";" : "=r" (d6) : "r" (d6) );
|
||||||
|
__asm__ volatile (";" : "=r" (d7) : "r" (d7) );
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int kk;
|
||||||
|
for (kk=k; kk > 0; kk--) {
|
||||||
|
vector double va00 = vec_splats( *(double *)( pa+0 ) );
|
||||||
|
vector double va10 = vec_splats( *(double *)( pa+d1 ) );
|
||||||
|
vector double va20 = vec_splats( *(double *)( pa+d2 ) );
|
||||||
|
vector double va30 = vec_splats( *(double *)( pa+d3 ) );
|
||||||
|
vector double va40 = vec_splats( *(double *)( pa+d4 ) );
|
||||||
|
vector double va50 = vec_splats( *(double *)( pa+d5 ) );
|
||||||
|
vector double va60 = vec_splats( *(double *)( pa+d6 ) );
|
||||||
|
vector double va70 = vec_splats( *(double *)( pa+d7 ) );
|
||||||
|
pa += 8*sizeof(double);
|
||||||
|
|
||||||
|
vector double vb00_01 = *(vector double *)( pb+0 );
|
||||||
|
vector double vb02_03 = *(vector double *)( pb+d2 );
|
||||||
|
pb += 4*sizeof(double);
|
||||||
|
|
||||||
|
vc00_01 = vec_madd(va00, vb00_01, vc00_01);
|
||||||
|
vc02_03 = vec_madd(va00, vb02_03, vc02_03);
|
||||||
|
vc10_11 = vec_madd(va10, vb00_01, vc10_11);
|
||||||
|
vc12_13 = vec_madd(va10, vb02_03, vc12_13);
|
||||||
|
vc20_21 = vec_madd(va20, vb00_01, vc20_21);
|
||||||
|
vc22_23 = vec_madd(va20, vb02_03, vc22_23);
|
||||||
|
vc30_31 = vec_madd(va30, vb00_01, vc30_31);
|
||||||
|
vc32_33 = vec_madd(va30, vb02_03, vc32_33);
|
||||||
|
vc40_41 = vec_madd(va40, vb00_01, vc40_41);
|
||||||
|
vc42_43 = vec_madd(va40, vb02_03, vc42_43);
|
||||||
|
vc50_51 = vec_madd(va50, vb00_01, vc50_51);
|
||||||
|
vc52_53 = vec_madd(va50, vb02_03, vc52_53);
|
||||||
|
vc60_61 = vec_madd(va60, vb00_01, vc60_61);
|
||||||
|
vc62_63 = vec_madd(va60, vb02_03, vc62_63);
|
||||||
|
vc70_71 = vec_madd(va70, vb00_01, vc70_71);
|
||||||
|
vc72_73 = vec_madd(va70, vb02_03, vc72_73);
|
||||||
|
}
|
||||||
|
|
||||||
|
vector double valpha = vec_splats( *alpha );
|
||||||
|
vector double vbeta = (vector double) { *beta, *beta };
|
||||||
|
|
||||||
|
vector double *pc = (vector double *)c;
|
||||||
|
|
||||||
|
vc00_01 = vec_mul(valpha, vc00_01);
|
||||||
|
vc02_03 = vec_mul(valpha, vc02_03);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc00_01);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc02_03);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc10_11 = vec_mul(valpha, vc10_11);
|
||||||
|
vc12_13 = vec_mul(valpha, vc12_13);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc10_11);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc12_13);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc20_21 = vec_mul(valpha, vc20_21);
|
||||||
|
vc22_23 = vec_mul(valpha, vc22_23);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc20_21);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc22_23);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc30_31 = vec_mul(valpha, vc30_31);
|
||||||
|
vc32_33 = vec_mul(valpha, vc32_33);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc30_31);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc32_33);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc40_41 = vec_mul(valpha, vc40_41);
|
||||||
|
vc42_43 = vec_mul(valpha, vc42_43);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc40_41);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc42_43);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc50_51 = vec_mul(valpha, vc50_51);
|
||||||
|
vc52_53 = vec_mul(valpha, vc52_53);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc50_51);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc52_53);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc60_61 = vec_mul(valpha, vc60_61);
|
||||||
|
vc62_63 = vec_mul(valpha, vc62_63);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc60_61);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc62_63);
|
||||||
|
pc += rs_c/2;
|
||||||
|
|
||||||
|
vc70_71 = vec_mul(valpha, vc70_71);
|
||||||
|
vc72_73 = vec_mul(valpha, vc72_73);
|
||||||
|
pc[0] = vec_madd( pc[0], vbeta, vc70_71);
|
||||||
|
pc[1] = vec_madd( pc[1], vbeta, vc72_73);
|
||||||
|
pc += rs_c/2;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
GEMM_UKR_SETUP_CT( d, 8, 4, false );
|
||||||
|
|
||||||
#if 1
|
|
||||||
if (rs_c == 1) {
|
|
||||||
// Optimized code for case where C columns are contiguous (column-major C)
|
// Optimized code for case where C columns are contiguous (column-major C)
|
||||||
vector double vzero = vec_splats( 0.0 );
|
vector double vzero = vec_splats( 0.0 );
|
||||||
|
|
||||||
@@ -301,168 +433,8 @@ void bli_dgemm_power7_int_8x4
|
|||||||
pc[1] = vec_madd( pc[1], vbeta, vc23_33);
|
pc[1] = vec_madd( pc[1], vbeta, vc23_33);
|
||||||
pc[2] = vec_madd( pc[2], vbeta, vc43_53);
|
pc[2] = vec_madd( pc[2], vbeta, vc43_53);
|
||||||
pc[3] = vec_madd( pc[3], vbeta, vc63_73);
|
pc[3] = vec_madd( pc[3], vbeta, vc63_73);
|
||||||
}
|
|
||||||
else
|
|
||||||
#endif
|
|
||||||
#if 1
|
|
||||||
if ( cs_c == 1 ) {
|
|
||||||
// Optimized code for case where C rows are contiguous (i.e. C is row-major)
|
|
||||||
|
|
||||||
vector double vzero = vec_splats( 0.0 );
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
|
|
||||||
vector double vc00_01 = vzero;
|
|
||||||
vector double vc02_03 = vzero;
|
|
||||||
vector double vc10_11 = vzero;
|
|
||||||
vector double vc12_13 = vzero;
|
|
||||||
vector double vc20_21 = vzero;
|
|
||||||
vector double vc22_23 = vzero;
|
|
||||||
vector double vc30_31 = vzero;
|
|
||||||
vector double vc32_33 = vzero;
|
|
||||||
vector double vc40_41 = vzero;
|
|
||||||
vector double vc42_43 = vzero;
|
|
||||||
vector double vc50_51 = vzero;
|
|
||||||
vector double vc52_53 = vzero;
|
|
||||||
vector double vc60_61 = vzero;
|
|
||||||
vector double vc62_63 = vzero;
|
|
||||||
vector double vc70_71 = vzero;
|
|
||||||
vector double vc72_73 = vzero;
|
|
||||||
|
|
||||||
unsigned long long pa = (unsigned long long)a;
|
|
||||||
unsigned long long pb = (unsigned long long)b;
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
unsigned long long d1 = 1*sizeof(double);
|
|
||||||
unsigned long long d2 = 2*sizeof(double);
|
|
||||||
unsigned long long d3 = 3*sizeof(double);
|
|
||||||
unsigned long long d4 = 4*sizeof(double);
|
|
||||||
unsigned long long d6 = 6*sizeof(double);
|
|
||||||
#else
|
|
||||||
// ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables
|
|
||||||
register unsigned long long d1 __asm ("r21") = 1*sizeof(double);
|
|
||||||
register unsigned long long d2 __asm ("r22") = 2*sizeof(double);
|
|
||||||
register unsigned long long d3 __asm ("r23") = 3*sizeof(double);
|
|
||||||
register unsigned long long d4 __asm ("r24") = 4*sizeof(double);
|
|
||||||
register unsigned long long d5 __asm ("r25") = 5*sizeof(double);
|
|
||||||
register unsigned long long d6 __asm ("r26") = 6*sizeof(double);
|
|
||||||
register unsigned long long d7 __asm ("r27") = 7*sizeof(double);
|
|
||||||
|
|
||||||
__asm__ volatile (";" : "=r" (d1) : "r" (d1) );
|
|
||||||
__asm__ volatile (";" : "=r" (d2) : "r" (d2) );
|
|
||||||
__asm__ volatile (";" : "=r" (d3) : "r" (d3) );
|
|
||||||
__asm__ volatile (";" : "=r" (d4) : "r" (d4) );
|
|
||||||
__asm__ volatile (";" : "=r" (d5) : "r" (d5) );
|
|
||||||
__asm__ volatile (";" : "=r" (d6) : "r" (d6) );
|
|
||||||
__asm__ volatile (";" : "=r" (d7) : "r" (d7) );
|
|
||||||
#endif
|
|
||||||
|
|
||||||
int kk;
|
|
||||||
for (kk=k; kk > 0; kk--) {
|
|
||||||
vector double va00 = vec_splats( *(double *)( pa+0 ) );
|
|
||||||
vector double va10 = vec_splats( *(double *)( pa+d1 ) );
|
|
||||||
vector double va20 = vec_splats( *(double *)( pa+d2 ) );
|
|
||||||
vector double va30 = vec_splats( *(double *)( pa+d3 ) );
|
|
||||||
vector double va40 = vec_splats( *(double *)( pa+d4 ) );
|
|
||||||
vector double va50 = vec_splats( *(double *)( pa+d5 ) );
|
|
||||||
vector double va60 = vec_splats( *(double *)( pa+d6 ) );
|
|
||||||
vector double va70 = vec_splats( *(double *)( pa+d7 ) );
|
|
||||||
pa += 8*sizeof(double);
|
|
||||||
|
|
||||||
vector double vb00_01 = *(vector double *)( pb+0 );
|
|
||||||
vector double vb02_03 = *(vector double *)( pb+d2 );
|
|
||||||
pb += 4*sizeof(double);
|
|
||||||
|
|
||||||
vc00_01 = vec_madd(va00, vb00_01, vc00_01);
|
|
||||||
vc02_03 = vec_madd(va00, vb02_03, vc02_03);
|
|
||||||
vc10_11 = vec_madd(va10, vb00_01, vc10_11);
|
|
||||||
vc12_13 = vec_madd(va10, vb02_03, vc12_13);
|
|
||||||
vc20_21 = vec_madd(va20, vb00_01, vc20_21);
|
|
||||||
vc22_23 = vec_madd(va20, vb02_03, vc22_23);
|
|
||||||
vc30_31 = vec_madd(va30, vb00_01, vc30_31);
|
|
||||||
vc32_33 = vec_madd(va30, vb02_03, vc32_33);
|
|
||||||
vc40_41 = vec_madd(va40, vb00_01, vc40_41);
|
|
||||||
vc42_43 = vec_madd(va40, vb02_03, vc42_43);
|
|
||||||
vc50_51 = vec_madd(va50, vb00_01, vc50_51);
|
|
||||||
vc52_53 = vec_madd(va50, vb02_03, vc52_53);
|
|
||||||
vc60_61 = vec_madd(va60, vb00_01, vc60_61);
|
|
||||||
vc62_63 = vec_madd(va60, vb02_03, vc62_63);
|
|
||||||
vc70_71 = vec_madd(va70, vb00_01, vc70_71);
|
|
||||||
vc72_73 = vec_madd(va70, vb02_03, vc72_73);
|
|
||||||
}
|
|
||||||
|
|
||||||
vector double valpha = vec_splats( *alpha );
|
|
||||||
vector double vbeta = (vector double) { *beta, *beta };
|
|
||||||
|
|
||||||
vector double *pc = (vector double *)c;
|
|
||||||
|
|
||||||
vc00_01 = vec_mul(valpha, vc00_01);
|
|
||||||
vc02_03 = vec_mul(valpha, vc02_03);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc00_01);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc02_03);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc10_11 = vec_mul(valpha, vc10_11);
|
|
||||||
vc12_13 = vec_mul(valpha, vc12_13);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc10_11);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc12_13);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc20_21 = vec_mul(valpha, vc20_21);
|
|
||||||
vc22_23 = vec_mul(valpha, vc22_23);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc20_21);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc22_23);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc30_31 = vec_mul(valpha, vc30_31);
|
|
||||||
vc32_33 = vec_mul(valpha, vc32_33);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc30_31);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc32_33);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc40_41 = vec_mul(valpha, vc40_41);
|
|
||||||
vc42_43 = vec_mul(valpha, vc42_43);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc40_41);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc42_43);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc50_51 = vec_mul(valpha, vc50_51);
|
|
||||||
vc52_53 = vec_mul(valpha, vc52_53);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc50_51);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc52_53);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc60_61 = vec_mul(valpha, vc60_61);
|
|
||||||
vc62_63 = vec_mul(valpha, vc62_63);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc60_61);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc62_63);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
vc70_71 = vec_mul(valpha, vc70_71);
|
|
||||||
vc72_73 = vec_mul(valpha, vc72_73);
|
|
||||||
pc[0] = vec_madd( pc[0], vbeta, vc70_71);
|
|
||||||
pc[1] = vec_madd( pc[1], vbeta, vc72_73);
|
|
||||||
pc += rs_c/2;
|
|
||||||
|
|
||||||
}
|
|
||||||
else
|
|
||||||
#endif
|
|
||||||
{ /* General case. Just do it right. */
|
|
||||||
#if 1 || defined(UTEST)
|
|
||||||
const long MR = BLIS_DEFAULT_MR_D, NR = BLIS_DEFAULT_NR_D;
|
|
||||||
const long LDA = MR, LDB = NR;
|
|
||||||
int i, j, kk;
|
|
||||||
double c00;
|
|
||||||
|
|
||||||
for (i=0; i < MR; i++) {
|
|
||||||
for (j=0; j < NR; j++) {
|
|
||||||
c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
|
||||||
for (kk=0; kk < k; kk++)
|
|
||||||
c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]);
|
|
||||||
c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
//BLIS_DGEMM_UKERNEL_REF(k, alpha, a, b, beta, c, rs_c, cs_c, data);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -477,30 +449,26 @@ void bli_dgemm_power7_int_8x4
|
|||||||
*/
|
*/
|
||||||
void bli_cgemm_power7_int_8x4
|
void bli_cgemm_power7_int_8x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
scomplex* restrict alpha,
|
scomplex* restrict alpha,
|
||||||
scomplex* restrict a,
|
scomplex* restrict a,
|
||||||
scomplex* restrict b,
|
scomplex* restrict b,
|
||||||
scomplex* restrict beta,
|
scomplex* restrict beta,
|
||||||
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
scomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
|
||||||
// different size than is expected by load instructions.
|
|
||||||
uint64_t k = k0;
|
|
||||||
uint64_t rs_c = rs_c0;
|
|
||||||
uint64_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
#if 1 || defined(UTEST)
|
#if 1 || defined(UTEST)
|
||||||
const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C;
|
const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C;
|
||||||
const long LDA = MR, LDB = NR;
|
const long LDA = MR, LDB = NR;
|
||||||
int i, j, kk;
|
int i, j, kk;
|
||||||
scomplex c00;
|
scomplex c00;
|
||||||
|
|
||||||
for (i=0; i < MR; i++) {
|
for (i=0; i < m; i++) {
|
||||||
for (j=0; j < NR; j++) {
|
for (j=0; j < n; j++) {
|
||||||
scomplex tmpc, tmpa, tmpb, tmp;
|
scomplex tmpc, tmpa, tmpb, tmp;
|
||||||
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||||
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
|
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
|
||||||
@@ -534,30 +502,26 @@ void bli_cgemm_power7_int_8x4
|
|||||||
*/
|
*/
|
||||||
void bli_zgemm_power7_int_8x4
|
void bli_zgemm_power7_int_8x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
scomplex* restrict alpha,
|
scomplex* restrict alpha,
|
||||||
scomplex* restrict a,
|
scomplex* restrict a,
|
||||||
scomplex* restrict b,
|
scomplex* restrict b,
|
||||||
scomplex* restrict beta,
|
scomplex* restrict beta,
|
||||||
scomplex* restrict c, inc_t rs_c0, inc_t cs_c0,
|
scomplex* restrict c, inc_t rs_c, inc_t cs_c,
|
||||||
auxinfo_t* restrict data,
|
auxinfo_t* restrict data,
|
||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
|
||||||
// different size than is expected by load instructions.
|
|
||||||
uint64_t k = k0;
|
|
||||||
uint64_t rs_c = rs_c0;
|
|
||||||
uint64_t cs_c = cs_c0;
|
|
||||||
|
|
||||||
#if 1 || defined(UTEST)
|
#if 1 || defined(UTEST)
|
||||||
const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z;
|
const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z;
|
||||||
const long LDA = MR, LDB = NR;
|
const long LDA = MR, LDB = NR;
|
||||||
int i, j, kk;
|
int i, j, kk;
|
||||||
dcomplex c00;
|
dcomplex c00;
|
||||||
|
|
||||||
for (i=0; i < MR; i++) {
|
for (i=0; i < m; i++) {
|
||||||
for (j=0; j < NR; j++) {
|
for (j=0; j < n; j++) {
|
||||||
dcomplex tmpc, tmpa, tmpb, tmp;
|
dcomplex tmpc, tmpa, tmpb, tmp;
|
||||||
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
//c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta;
|
||||||
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
|
tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)];
|
||||||
|
|||||||
@@ -43,6 +43,8 @@
|
|||||||
|
|
||||||
void bli_sgemm_opt_8x4
|
void bli_sgemm_opt_8x4
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
@@ -55,6 +57,8 @@ void bli_sgemm_opt_8x4
|
|||||||
|
|
||||||
void bli_dgemm_opt_8x4
|
void bli_dgemm_opt_8x4
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
@@ -67,6 +71,8 @@ void bli_dgemm_opt_8x4
|
|||||||
|
|
||||||
void bli_cgemm_opt_8x4
|
void bli_cgemm_opt_8x4
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
scomplex* restrict alpha,
|
scomplex* restrict alpha,
|
||||||
scomplex* restrict a,
|
scomplex* restrict a,
|
||||||
@@ -79,6 +85,8 @@ void bli_cgemm_opt_8x4
|
|||||||
|
|
||||||
void bli_zgemm_opt_8x4
|
void bli_zgemm_opt_8x4
|
||||||
(
|
(
|
||||||
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
dim_t k,
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
|
|||||||
@@ -37,7 +37,9 @@
|
|||||||
|
|
||||||
void bli_dgemm_power9_asm_12x6
|
void bli_dgemm_power9_asm_12x6
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
double* restrict b,
|
double* restrict b,
|
||||||
@@ -50,117 +52,91 @@ void bli_dgemm_power9_asm_12x6
|
|||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
|
|
||||||
uint64_t k_iter = k0 / 16;
|
uint64_t k_iter = k / 16;
|
||||||
uint64_t k_left = k0 % 16;
|
uint64_t k_left = k % 16;
|
||||||
|
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
|
|
||||||
|
GEMM_UKR_SETUP_CT( d, 12, 6, false );
|
||||||
|
|
||||||
__asm__ volatile
|
__asm__ volatile
|
||||||
(
|
(
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"ld %%r7, %2 \n\t" // load ptr of A
|
"ld %%r7, %2 \n\t" // load ptr of A
|
||||||
"ld %%r8, %3 \n\t" // load ptr of B
|
"ld %%r8, %3 \n\t" // load ptr of B
|
||||||
"ld %%r16, %6 \n\t" // load ptr of C
|
"ld %%r16, %6 \n\t" // load ptr of C
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"ld %%r28, %4 \n\t" // load ptr for alpha
|
"ld %%r28, %4 \n\t" // load ptr for alpha
|
||||||
"ld %%r29, %5 \n\t" // load ptr for beta
|
"ld %%r29, %5 \n\t" // load ptr for beta
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"ld %%r11, %0 \n\t" // load k_iter
|
"ld %%r11, %0 \n\t" // load k_iter
|
||||||
"ld %%r12, %1 \n\t" // load k_left
|
"ld %%r12, %1 \n\t" // load k_left
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"ld %%r10, %8 \n\t" // load cs_c
|
"ld %%r10, %8 \n\t" // load cs_c
|
||||||
"slwi %%r10, %%r10, 3 \n\t" // mul by size of elem
|
"slwi %%r10, %%r10, 3 \n\t" // mul by size of elem
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"ld %%r9, %7 \n\t" // load rs_c
|
"ld %%r9, %7 \n\t" // load rs_c
|
||||||
"slwi %%r9, %%r9, 3 \n\t" // mul by size of elem
|
"slwi %%r9, %%r9, 3 \n\t" // mul by size of elem
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"ld %%r26, 0(%%r29) \n\t" // load val of beta
|
"ld %%r26, 0(%%r29) \n\t" // load val of beta
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha
|
"lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha
|
||||||
"lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta
|
"lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C
|
"add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C
|
||||||
"add %%r18, %%r17, %%r10 \n\t" // col 2 of C
|
"add %%r18, %%r17, %%r10 \n\t" // col 2 of C
|
||||||
"add %%r19, %%r18, %%r10 \n\t" // col 3 of C
|
"add %%r19, %%r18, %%r10 \n\t" // col 3 of C
|
||||||
"add %%r20, %%r19, %%r10 \n\t" // col 4 of C
|
"add %%r20, %%r19, %%r10 \n\t" // col 4 of C
|
||||||
"add %%r21, %%r20, %%r10 \n\t" // col 5 of C
|
"add %%r21, %%r20, %%r10 \n\t" // col 5 of C
|
||||||
" \n\t"
|
" \n\t"
|
||||||
DZERO_OUT_VREG
|
DZERO_OUT_VREG
|
||||||
" \n\t"
|
" \n\t"
|
||||||
DPRELOAD
|
DPRELOAD
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B
|
"addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B
|
||||||
"addi %%r7, %%r7, 96 \n\t"
|
"addi %%r7, %%r7, 96 \n\t"
|
||||||
" \n\t"
|
" \n\t"
|
||||||
DPREFETCH
|
DPREFETCH
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"cmpwi %%r11, 0 \n\t" // if k_iter == 0,
|
"cmpwi %%r11, 0 \n\t" // if k_iter == 0,
|
||||||
"beq DCONSIDERKLEFT \n\t" // then jmp to k_left
|
"beq DCONSIDERKLEFT \n\t" // then jmp to k_left
|
||||||
"mtctr %%r11 \n\t" // else, do k_iter loop
|
"mtctr %%r11 \n\t" // else, do k_iter loop
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"DLOOPKITER: \n\t" // k_iter loop
|
"DLOOPKITER: \n\t" // k_iter loop
|
||||||
" \n\t"
|
" \n\t"
|
||||||
A_B_PRODUCT_16 // compute A*B
|
A_B_PRODUCT_16 // compute A*B
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"bdnz DLOOPKITER \n\t"
|
"bdnz DLOOPKITER \n\t"
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"DCONSIDERKLEFT: \n\t"
|
"DCONSIDERKLEFT: \n\t"
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"cmpwi %%r12, 0 \n\t" // if k_left == 0,
|
"cmpwi %%r12, 0 \n\t" // if k_left == 0,
|
||||||
"beq DPOSTACCUM \n\t" // then jmp to post accum
|
"beq DPOSTACCUM \n\t" // then jmp to post accum
|
||||||
"mtctr %%r12 \n\t" // else, do k_left loop
|
"mtctr %%r12 \n\t" // else, do k_left loop
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"DLOOPKLEFT: \n\t" // k_left loop
|
"DLOOPKLEFT: \n\t" // k_left loop
|
||||||
" \n\t"
|
" \n\t"
|
||||||
A_B_PRODUCT_1
|
A_B_PRODUCT_1
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"bdnz DLOOPKLEFT \n\t"
|
"bdnz DLOOPKLEFT \n\t"
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"DPOSTACCUM: \n\t"
|
"DPOSTACCUM: \n\t"
|
||||||
" \n\t"
|
" \n\t"
|
||||||
DSCALE_ALPHA
|
DSCALE_ALPHA
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"cmpdi %%r26, 0 \n\t" // if beta == 0,
|
"cmpdi %%r26, 0 \n\t" // if beta == 0,
|
||||||
"beq DBETAZERO \n\t" // then jmp to BZ
|
"beq DBETAZERO \n\t" // then jmp to BZ
|
||||||
" \n\t"
|
" \n\t"
|
||||||
"cmpwi %%r9, 8 \n\t" // if rs_c == 8
|
DCOL_SCALE_BETA
|
||||||
"beq DCOLSTOREDBNZ \n\t" // then jmp to col store
|
" \n\t"
|
||||||
" \n\t"
|
"DBETAZERO: \n\t" // BZ case
|
||||||
"DGENSTOREDBNZ: \n\t" // BNZ gen stored case
|
" \n\t"
|
||||||
" \n\t"
|
DCOL_STORE
|
||||||
DGEN_LOAD_OFS_C
|
" \n\t"
|
||||||
" \n\t"
|
"DDONE: \n\t"
|
||||||
DGEN_SCALE_BETA
|
" \n\t"
|
||||||
" \n\t"
|
: // output operands (none)
|
||||||
"b DGENSTORED \n\t"
|
|
||||||
" \n\t"
|
|
||||||
"DCOLSTOREDBNZ: \n\t" // BNZ col stored case
|
|
||||||
" \n\t"
|
|
||||||
DCOL_SCALE_BETA
|
|
||||||
" \n\t"
|
|
||||||
"b DCOLSTORED \n\t"
|
|
||||||
" \n\t"
|
|
||||||
"DBETAZERO: \n\t" // BZ case
|
|
||||||
" \n\t"
|
|
||||||
"cmpwi %%r9, 8 \n\t" // if rs_c == 8,
|
|
||||||
"beq DCOLSTORED \n\t" // C is col stored
|
|
||||||
" \n\t"
|
|
||||||
"DGENSTORED: \n\t" // BZ gen stored case
|
|
||||||
" \n\t"
|
|
||||||
DGEN_LOAD_OFS_C
|
|
||||||
" \n\t"
|
|
||||||
DGEN_STORE
|
|
||||||
" \n\t"
|
|
||||||
"b DDONE \n\t"
|
|
||||||
" \n\t"
|
|
||||||
"DCOLSTORED: \n\t" // BZ col stored case
|
|
||||||
" \n\t"
|
|
||||||
DCOL_STORE
|
|
||||||
" \n\t"
|
|
||||||
"DDONE: \n\t"
|
|
||||||
" \n\t"
|
|
||||||
: // output operands (none)
|
|
||||||
: // input operands
|
: // input operands
|
||||||
"m" (k_iter), // 0
|
"m" (k_iter), // 0
|
||||||
"m" (k_left), // 1
|
"m" (k_left), // 1
|
||||||
@@ -174,28 +150,30 @@ void bli_dgemm_power9_asm_12x6
|
|||||||
"m" (b_next), // 9
|
"m" (b_next), // 9
|
||||||
"m" (a_next)*/ // 10
|
"m" (a_next)*/ // 10
|
||||||
: // register clobber list
|
: // register clobber list
|
||||||
/* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */
|
/* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */
|
||||||
"r0", "r7", "r8", "r9",
|
"r0", "r7", "r8", "r9",
|
||||||
"r10", "r11", "r12", "r16", "r17", "r18", "r19",
|
"r10", "r11", "r12", "r16", "r17", "r18", "r19",
|
||||||
"r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29"
|
"r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29"
|
||||||
|
|
||||||
#if XLC
|
#if XLC
|
||||||
,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"
|
,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"
|
||||||
, "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19"
|
, "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19"
|
||||||
, "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29"
|
, "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29"
|
||||||
, "f30" ,"f31"
|
, "f30" ,"f31"
|
||||||
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9"
|
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9"
|
||||||
, "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"
|
, "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19"
|
||||||
, "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"
|
, "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"
|
||||||
, "v30", "v31"
|
, "v30", "v31"
|
||||||
#else
|
#else
|
||||||
, "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9"
|
, "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9"
|
||||||
, "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19"
|
, "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19"
|
||||||
, "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29"
|
, "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29"
|
||||||
, "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39"
|
, "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39"
|
||||||
, "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49"
|
, "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49"
|
||||||
, "vs50", "vs51", "vs52", "vs53"
|
, "vs50", "vs51", "vs52", "vs53"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -32,14 +32,17 @@
|
|||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <immintrin.h>
|
#include <emmintrin.h>
|
||||||
|
#include <immintrin.h>
|
||||||
#include "blis.h"
|
#include "blis.h"
|
||||||
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
void bli_sgemm_sandybridge_int_8x8
|
void bli_sgemm_sandybridge_int_8x8
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
float* restrict alpha,
|
float* restrict alpha,
|
||||||
float* restrict a,
|
float* restrict a,
|
||||||
float* restrict b,
|
float* restrict b,
|
||||||
@@ -52,11 +55,11 @@ void bli_sgemm_sandybridge_int_8x8
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void bli_dgemm_sandybridge_int_8x4
|
void bli_dgemm_sandybridge_int_8x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
double* restrict alpha,
|
double* restrict alpha,
|
||||||
double* restrict a,
|
double* restrict a,
|
||||||
double* restrict b,
|
double* restrict b,
|
||||||
@@ -66,19 +69,22 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
cntx_t* restrict cntx
|
cntx_t* restrict cntx
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
|
|
||||||
//void* a_next = bli_auxinfo_next_a( data );
|
//void* a_next = bli_auxinfo_next_a( data );
|
||||||
void* b_next = bli_auxinfo_next_b( data );
|
void* b_next = bli_auxinfo_next_b( data );
|
||||||
|
|
||||||
// Typecast local copies of integers in case dim_t and inc_t are a
|
// Typecast local copies of integers in case dim_t and inc_t are a
|
||||||
// different size than is expected by load instructions.
|
// different size than is expected by load instructions.
|
||||||
uint64_t k_iter = k0 / 2;
|
uint64_t k_iter = k / 2;
|
||||||
uint64_t k_left = k0 % 2;
|
uint64_t k_left = k % 2;
|
||||||
uint64_t rs_c = rs_c0;
|
uint64_t rs_c = rs_c0;
|
||||||
uint64_t cs_c = cs_c0;
|
uint64_t cs_c = cs_c0;
|
||||||
uint64_t i;
|
uint64_t i;
|
||||||
|
|
||||||
double *c00, *c01, *c02, *c03;
|
GEMM_UKR_SETUP_CT( d, 8, 4, false );
|
||||||
double *c40, *c41, *c42, *c43;
|
|
||||||
|
double *c00, *c01, *c02, *c03;
|
||||||
|
double *c40, *c41, *c42, *c43;
|
||||||
|
|
||||||
// Quad registers.
|
// Quad registers.
|
||||||
__m256d va0_3, va4_7;
|
__m256d va0_3, va4_7;
|
||||||
@@ -87,23 +93,20 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
__m256d vb;
|
__m256d vb;
|
||||||
__m256d vB0;
|
__m256d vB0;
|
||||||
|
|
||||||
__m256d va0_3b_0, va4_7b_0;
|
__m256d va0_3b_0, va4_7b_0;
|
||||||
__m256d va0_3b_1, va4_7b_1;
|
__m256d va0_3b_1, va4_7b_1;
|
||||||
__m256d va0_3b_2, va4_7b_2;
|
__m256d va0_3b_2, va4_7b_2;
|
||||||
__m256d va0_3b_3, va4_7b_3;
|
__m256d va0_3b_3, va4_7b_3;
|
||||||
|
|
||||||
__m256d va0_3b0, va4_7b0;
|
__m256d va0_3b0, va4_7b0;
|
||||||
__m256d va0_3b1, va4_7b1;
|
__m256d va0_3b1, va4_7b1;
|
||||||
__m256d va0_3b2, va4_7b2;
|
__m256d va0_3b2, va4_7b2;
|
||||||
__m256d va0_3b3, va4_7b3;
|
__m256d va0_3b3, va4_7b3;
|
||||||
|
|
||||||
|
__m256d valpha, vbeta, vtmp;
|
||||||
__m256d valpha, vbeta, vtmp;
|
|
||||||
__m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3;
|
__m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3;
|
||||||
__m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3;
|
__m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3;
|
||||||
|
|
||||||
__m128d aa, bb;
|
|
||||||
|
|
||||||
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) );
|
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) );
|
||||||
__asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) );
|
__asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) );
|
||||||
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) );
|
__asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) );
|
||||||
@@ -129,19 +132,19 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_3 = _mm256_setzero_pd();
|
va4_7b_3 = _mm256_setzero_pd();
|
||||||
|
|
||||||
// Load va0_3
|
// Load va0_3
|
||||||
va0_3 = _mm256_load_pd( a );
|
va0_3 = _mm256_load_pd( a );
|
||||||
// Load va4_7
|
// Load va4_7
|
||||||
va4_7 = _mm256_load_pd( a + 4 );
|
va4_7 = _mm256_load_pd( a + 4 );
|
||||||
|
|
||||||
// Load vb (b0,b1,b2,b3)
|
// Load vb (b0,b1,b2,b3)
|
||||||
vb0 = _mm256_load_pd( b );
|
vb0 = _mm256_load_pd( b );
|
||||||
|
|
||||||
for( i = 0; i < k_iter; ++i )
|
for( i = 0; i < k_iter; ++i )
|
||||||
{
|
{
|
||||||
__asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) );
|
__asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) );
|
||||||
|
|
||||||
// Load va0_3 (Prefetch)
|
// Load va0_3 (Prefetch)
|
||||||
vA0_3 = _mm256_load_pd( a + 8 );
|
vA0_3 = _mm256_load_pd( a + 8 );
|
||||||
|
|
||||||
// Iteration 0.
|
// Iteration 0.
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb0 );
|
vtmp = _mm256_mul_pd( va0_3, vb0 );
|
||||||
@@ -151,10 +154,10 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
||||||
|
|
||||||
// Load va4_7 (Prefetch)
|
// Load va4_7 (Prefetch)
|
||||||
vA4_7 = _mm256_load_pd( a + 12 );
|
vA4_7 = _mm256_load_pd( a + 12 );
|
||||||
|
|
||||||
// Shuffle vb (b1,b0,b3,b2)
|
// Shuffle vb (b1,b0,b3,b2)
|
||||||
vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 );
|
vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb1 );
|
vtmp = _mm256_mul_pd( va0_3, vb1 );
|
||||||
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
||||||
@@ -163,10 +166,10 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
||||||
|
|
||||||
// Permute vb (b3,b2,b1,b0)
|
// Permute vb (b3,b2,b1,b0)
|
||||||
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
||||||
|
|
||||||
// Load vb (b0,b1,b2,b3) (Prefetch)
|
// Load vb (b0,b1,b2,b3) (Prefetch)
|
||||||
vB0 = _mm256_load_pd( b + 4 );
|
vB0 = _mm256_load_pd( b + 4 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb2 );
|
vtmp = _mm256_mul_pd( va0_3, vb2 );
|
||||||
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
||||||
@@ -175,7 +178,7 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
||||||
|
|
||||||
// Shuffle vb (b3,b2,b1,b0)
|
// Shuffle vb (b3,b2,b1,b0)
|
||||||
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb3 );
|
vtmp = _mm256_mul_pd( va0_3, vb3 );
|
||||||
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
||||||
@@ -186,14 +189,14 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
// Iteration 1.
|
// Iteration 1.
|
||||||
|
|
||||||
__asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) );
|
__asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) );
|
||||||
|
|
||||||
// Load va0_3 (Next iteration)
|
// Load va0_3 (Next iteration)
|
||||||
va0_3 = _mm256_load_pd( a + 16 );
|
va0_3 = _mm256_load_pd( a + 16 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( vA0_3, vB0 );
|
vtmp = _mm256_mul_pd( vA0_3, vB0 );
|
||||||
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
|
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
|
||||||
|
|
||||||
vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 );
|
vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( vA4_7, vB0 );
|
vtmp = _mm256_mul_pd( vA4_7, vB0 );
|
||||||
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
||||||
@@ -202,9 +205,9 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
||||||
|
|
||||||
// Load va4_7 (Next iteration)
|
// Load va4_7 (Next iteration)
|
||||||
va4_7 = _mm256_load_pd( a + 20 );
|
va4_7 = _mm256_load_pd( a + 20 );
|
||||||
|
|
||||||
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( vA4_7, vb1 );
|
vtmp = _mm256_mul_pd( vA4_7, vb1 );
|
||||||
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
||||||
@@ -212,13 +215,13 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vtmp = _mm256_mul_pd( vA0_3, vb2 );
|
vtmp = _mm256_mul_pd( vA0_3, vb2 );
|
||||||
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
||||||
|
|
||||||
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( vA4_7, vb2 );
|
vtmp = _mm256_mul_pd( vA4_7, vb2 );
|
||||||
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
||||||
|
|
||||||
// Load vb0(Next iteration)
|
// Load vb0(Next iteration)
|
||||||
vb0 = _mm256_load_pd( b + 8 );
|
vb0 = _mm256_load_pd( b + 8 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( vA0_3, vb3 );
|
vtmp = _mm256_mul_pd( vA0_3, vb3 );
|
||||||
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
||||||
@@ -236,12 +239,12 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
// Iteration 0.
|
// Iteration 0.
|
||||||
|
|
||||||
// Load va0_3
|
// Load va0_3
|
||||||
va0_3 = _mm256_load_pd( a );
|
va0_3 = _mm256_load_pd( a );
|
||||||
// Load va4_7
|
// Load va4_7
|
||||||
va4_7 = _mm256_load_pd( a + 4 );
|
va4_7 = _mm256_load_pd( a + 4 );
|
||||||
|
|
||||||
// Load vb (b0,b1,b2,b3)
|
// Load vb (b0,b1,b2,b3)
|
||||||
vb = _mm256_load_pd( b );
|
vb = _mm256_load_pd( b );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||||
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
|
va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp );
|
||||||
@@ -250,7 +253,7 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp );
|
||||||
|
|
||||||
// Shuffle vb (b1,b0,b3,b2)
|
// Shuffle vb (b1,b0,b3,b2)
|
||||||
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||||
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp );
|
||||||
@@ -259,7 +262,7 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp );
|
||||||
|
|
||||||
// Permute vb (b3,b2,b1,b0)
|
// Permute vb (b3,b2,b1,b0)
|
||||||
vb = _mm256_permute2f128_pd( vb, vb, 0x1 );
|
vb = _mm256_permute2f128_pd( vb, vb, 0x1 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||||
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp );
|
||||||
@@ -268,7 +271,7 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp );
|
||||||
|
|
||||||
// Shuffle vb (b3,b2,b1,b0)
|
// Shuffle vb (b3,b2,b1,b0)
|
||||||
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
vb = _mm256_shuffle_pd( vb, vb, 0x5 );
|
||||||
|
|
||||||
vtmp = _mm256_mul_pd( va0_3, vb );
|
vtmp = _mm256_mul_pd( va0_3, vb );
|
||||||
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp );
|
||||||
@@ -309,12 +312,72 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 );
|
va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 );
|
||||||
va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 );
|
va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 );
|
||||||
|
|
||||||
if( rs_c == 1 )
|
__m128d vzero = _mm_setzero_pd( );
|
||||||
|
|
||||||
|
if( _mm_comieq_sd( _mm256_castpd256_pd128(vbeta), vzero ) )
|
||||||
{
|
{
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c00 = ( c + 0*rs_c + 0*cs_c );
|
c00 = ( c + 0 + 0*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c00, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c40 = ( c + 4 + 0*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c40, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c01 = ( c + 0 + 1*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c01, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c41 = ( c + 4 + 1*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c41, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c02 = ( c + 0 + 2*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c02, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c42 = ( c + 4 + 2*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c42, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c03 = ( c + 0 + 3*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c03, vtmp );
|
||||||
|
|
||||||
|
// Calculate address
|
||||||
|
c43 = ( c + 4 + 3*cs_c );
|
||||||
|
// Scale by alpha
|
||||||
|
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
||||||
|
// Store back to memory
|
||||||
|
_mm256_store_pd( c43, vtmp );
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Calculate address
|
||||||
|
c00 = ( c + 0 + 0*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c );
|
//vc0_3_0 = _mm256_load_pd( c + 0 + 0*cs_c );
|
||||||
vc0_3_0 = _mm256_load_pd( c00 );
|
vc0_3_0 = _mm256_load_pd( c00 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
||||||
@@ -324,11 +387,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp );
|
vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c00, vc0_3_0 );
|
_mm256_store_pd( c00, vc0_3_0 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c40 = ( c + 4*rs_c + 0*cs_c );
|
c40 = ( c + 4 + 0*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c );
|
//vc4_7_0 = _mm256_load_pd( c + 4 + 0*cs_c );
|
||||||
vc4_7_0 = _mm256_load_pd( c40 );
|
vc4_7_0 = _mm256_load_pd( c40 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
||||||
@@ -338,11 +401,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp );
|
vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c40, vc4_7_0 );
|
_mm256_store_pd( c40, vc4_7_0 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c01 = ( c + 0*rs_c + 1*cs_c );
|
c01 = ( c + 0 + 1*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c );
|
//vc0_3_1 = _mm256_load_pd( c + 0 + 1*cs_c );
|
||||||
vc0_3_1 = _mm256_load_pd( c01 );
|
vc0_3_1 = _mm256_load_pd( c01 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
||||||
@@ -352,11 +415,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp );
|
vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c01, vc0_3_1 );
|
_mm256_store_pd( c01, vc0_3_1 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c41 = ( c + 4*rs_c + 1*cs_c );
|
c41 = ( c + 4 + 1*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c );
|
//vc4_7_1 = _mm256_load_pd( c + 4 + 1*cs_c );
|
||||||
vc4_7_1 = _mm256_load_pd( c41 );
|
vc4_7_1 = _mm256_load_pd( c41 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
||||||
@@ -366,11 +429,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp );
|
vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c41, vc4_7_1 );
|
_mm256_store_pd( c41, vc4_7_1 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c02 = ( c + 0*rs_c + 2*cs_c );
|
c02 = ( c + 0 + 2*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c );
|
//vc0_3_2 = _mm256_load_pd( c + 0 + 2*cs_c );
|
||||||
vc0_3_2 = _mm256_load_pd( c02 );
|
vc0_3_2 = _mm256_load_pd( c02 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
||||||
@@ -380,11 +443,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp );
|
vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c02, vc0_3_2 );
|
_mm256_store_pd( c02, vc0_3_2 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c42 = ( c + 4*rs_c + 2*cs_c );
|
c42 = ( c + 4 + 2*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c );
|
//vc4_7_2 = _mm256_load_pd( c + 4 + 2*cs_c );
|
||||||
vc4_7_2 = _mm256_load_pd( c42 );
|
vc4_7_2 = _mm256_load_pd( c42 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
||||||
@@ -394,11 +457,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp );
|
vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c42, vc4_7_2 );
|
_mm256_store_pd( c42, vc4_7_2 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c03 = ( c + 0*rs_c + 3*cs_c );
|
c03 = ( c + 0 + 3*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c );
|
//vc0_3_3 = _mm256_load_pd( c + 0 + 3*cs_c );
|
||||||
vc0_3_3 = _mm256_load_pd( c03 );
|
vc0_3_3 = _mm256_load_pd( c03 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
||||||
@@ -408,11 +471,11 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp );
|
vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c03, vc0_3_3 );
|
_mm256_store_pd( c03, vc0_3_3 );
|
||||||
|
|
||||||
// Calculate address
|
// Calculate address
|
||||||
c43 = ( c + 4*rs_c + 3*cs_c );
|
c43 = ( c + 4 + 3*cs_c );
|
||||||
// Load
|
// Load
|
||||||
//vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c );
|
//vc4_7_3 = _mm256_load_pd( c + 4 + 3*cs_c );
|
||||||
vc4_7_3 = _mm256_load_pd( c43 );
|
vc4_7_3 = _mm256_load_pd( c43 );
|
||||||
// Scale by alpha
|
// Scale by alpha
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
||||||
@@ -422,211 +485,9 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
|
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
|
||||||
// Store back to memory
|
// Store back to memory
|
||||||
_mm256_store_pd( c43, vc4_7_3 );
|
_mm256_store_pd( c43, vc4_7_3 );
|
||||||
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// Calculate address
|
|
||||||
c00 = ( c + 0*rs_c + 0*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c );
|
|
||||||
vc0_3_0 = _mm256_set_pd( *(c + 3*rs_c + 0*cs_c ),
|
|
||||||
*(c + 2*rs_c + 0*cs_c ),
|
|
||||||
*(c + 1*rs_c + 0*cs_c ),
|
|
||||||
*(c + 0*rs_c + 0*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b0);
|
|
||||||
// Scale by beta
|
|
||||||
vc0_3_0 = _mm256_mul_pd( vbeta, vc0_3_0 );
|
|
||||||
// Add gemm result
|
|
||||||
vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c00, vc0_3_0 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc0_3_0, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc0_3_0, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 0*rs_c + 0*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 1*rs_c + 0*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 2*rs_c + 0*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 3*rs_c + 0*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c40 = ( c + 4*rs_c + 0*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c );
|
|
||||||
vc4_7_0 = _mm256_set_pd( *(c + 7*rs_c + 0*cs_c ),
|
|
||||||
*(c + 6*rs_c + 0*cs_c ),
|
|
||||||
*(c + 5*rs_c + 0*cs_c ),
|
|
||||||
*(c + 4*rs_c + 0*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b0);
|
|
||||||
// Scale by beta
|
|
||||||
vc4_7_0 = _mm256_mul_pd( vbeta, vc4_7_0 );
|
|
||||||
// Add gemm result
|
|
||||||
vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c40, vc4_7_0 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc4_7_0, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc4_7_0, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 4*rs_c + 0*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 5*rs_c + 0*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 6*rs_c + 0*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 7*rs_c + 0*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c01 = ( c + 0*rs_c + 1*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c );
|
|
||||||
vc0_3_1 = _mm256_set_pd( *(c + 3*rs_c + 1*cs_c ),
|
|
||||||
*(c + 2*rs_c + 1*cs_c ),
|
|
||||||
*(c + 1*rs_c + 1*cs_c ),
|
|
||||||
*(c + 0*rs_c + 1*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b1);
|
|
||||||
// Scale by beta
|
|
||||||
vc0_3_1 = _mm256_mul_pd( vbeta, vc0_3_1 );
|
|
||||||
// Add gemm result
|
|
||||||
vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c01, vc0_3_1 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc0_3_1, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc0_3_1, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 0*rs_c + 1*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 1*rs_c + 1*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 2*rs_c + 1*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 3*rs_c + 1*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c41 = ( c + 4*rs_c + 1*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c );
|
|
||||||
vc4_7_1 = _mm256_set_pd( *(c + 7*rs_c + 1*cs_c ),
|
|
||||||
*(c + 6*rs_c + 1*cs_c ),
|
|
||||||
*(c + 5*rs_c + 1*cs_c ),
|
|
||||||
*(c + 4*rs_c + 1*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b1);
|
|
||||||
// Scale by beta
|
|
||||||
vc4_7_1 = _mm256_mul_pd( vbeta, vc4_7_1 );
|
|
||||||
// Add gemm result
|
|
||||||
vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c41, vc4_7_1 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc4_7_1, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc4_7_1, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 4*rs_c + 1*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 5*rs_c + 1*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 6*rs_c + 1*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 7*rs_c + 1*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c02 = ( c + 0*rs_c + 2*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c );
|
|
||||||
vc0_3_2 = _mm256_set_pd( *(c + 3*rs_c + 2*cs_c ),
|
|
||||||
*(c + 2*rs_c + 2*cs_c ),
|
|
||||||
*(c + 1*rs_c + 2*cs_c ),
|
|
||||||
*(c + 0*rs_c + 2*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b2);
|
|
||||||
// Scale by beta
|
|
||||||
vc0_3_2 = _mm256_mul_pd( vbeta, vc0_3_2 );
|
|
||||||
// Add gemm result
|
|
||||||
vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c02, vc0_3_2 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc0_3_2, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc0_3_2, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 0*rs_c + 2*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 1*rs_c + 2*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 2*rs_c + 2*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 3*rs_c + 2*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c42 = ( c + 4*rs_c + 2*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c );
|
|
||||||
vc4_7_2 = _mm256_set_pd( *(c + 7*rs_c + 2*cs_c ),
|
|
||||||
*(c + 6*rs_c + 2*cs_c ),
|
|
||||||
*(c + 5*rs_c + 2*cs_c ),
|
|
||||||
*(c + 4*rs_c + 2*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b2);
|
|
||||||
// Scale by beta
|
|
||||||
vc4_7_2 = _mm256_mul_pd( vbeta, vc4_7_2 );
|
|
||||||
// Add gemm result
|
|
||||||
vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c42, vc4_7_2 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc4_7_2, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc4_7_2, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 4*rs_c + 2*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 5*rs_c + 2*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 6*rs_c + 2*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 7*rs_c + 2*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c03 = ( c + 0*rs_c + 3*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c );
|
|
||||||
vc0_3_3 = _mm256_set_pd( *(c + 3*rs_c + 3*cs_c ),
|
|
||||||
*(c + 2*rs_c + 3*cs_c ),
|
|
||||||
*(c + 1*rs_c + 3*cs_c ),
|
|
||||||
*(c + 0*rs_c + 3*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va0_3b3);
|
|
||||||
// Scale by beta
|
|
||||||
vc0_3_3 = _mm256_mul_pd( vbeta, vc0_3_3 );
|
|
||||||
// Add gemm result
|
|
||||||
vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c03, vc0_3_3 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc0_3_3, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc0_3_3, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 0*rs_c + 3*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 1*rs_c + 3*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 2*rs_c + 3*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 3*rs_c + 3*cs_c, bb );
|
|
||||||
|
|
||||||
// Calculate address
|
|
||||||
c43 = ( c + 4*rs_c + 3*cs_c );
|
|
||||||
// Load
|
|
||||||
//vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c );
|
|
||||||
vc4_7_3 = _mm256_set_pd( *(c + 7*rs_c + 3*cs_c ),
|
|
||||||
*(c + 6*rs_c + 3*cs_c ),
|
|
||||||
*(c + 5*rs_c + 3*cs_c ),
|
|
||||||
*(c + 4*rs_c + 3*cs_c ) );
|
|
||||||
// Scale by alpha
|
|
||||||
vtmp = _mm256_mul_pd( valpha, va4_7b3);
|
|
||||||
// Scale by beta
|
|
||||||
vc4_7_3 = _mm256_mul_pd( vbeta, vc4_7_3 );
|
|
||||||
// Add gemm result
|
|
||||||
vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp );
|
|
||||||
// Store back to memory
|
|
||||||
//_mm256_store_pd( c43, vc4_7_3 );
|
|
||||||
|
|
||||||
aa = _mm256_extractf128_pd( vc4_7_3, 0 ) ;
|
|
||||||
bb = _mm256_extractf128_pd( vc4_7_3, 1 ) ;
|
|
||||||
|
|
||||||
_mm_storel_pd( c + 4*rs_c + 3*cs_c, aa );
|
|
||||||
_mm_storeh_pd( c + 5*rs_c + 3*cs_c, aa );
|
|
||||||
_mm_storel_pd( c + 6*rs_c + 3*cs_c, bb );
|
|
||||||
_mm_storeh_pd( c + 7*rs_c + 3*cs_c, bb );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -634,7 +495,9 @@ void bli_dgemm_sandybridge_int_8x4
|
|||||||
#if 0
|
#if 0
|
||||||
void bli_cgemm_sandybridge_int_8x4
|
void bli_cgemm_sandybridge_int_8x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
scomplex* restrict alpha,
|
scomplex* restrict alpha,
|
||||||
scomplex* restrict a,
|
scomplex* restrict a,
|
||||||
scomplex* restrict b,
|
scomplex* restrict b,
|
||||||
@@ -652,7 +515,9 @@ void bli_cgemm_sandybridge_int_8x4
|
|||||||
#if 0
|
#if 0
|
||||||
void bli_zgemm_sandybridge_int_4x4
|
void bli_zgemm_sandybridge_int_4x4
|
||||||
(
|
(
|
||||||
dim_t k0,
|
dim_t m,
|
||||||
|
dim_t n,
|
||||||
|
dim_t k,
|
||||||
dcomplex* restrict alpha,
|
dcomplex* restrict alpha,
|
||||||
dcomplex* restrict a,
|
dcomplex* restrict a,
|
||||||
dcomplex* restrict b,
|
dcomplex* restrict b,
|
||||||
|
|||||||
@@ -287,24 +287,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) =
|
|||||||
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
||||||
|
|
||||||
|
|
||||||
void bli_dgemm_skx_asm_16x12_l2(
|
void bli_dgemm_skx_asm_16x12_l2
|
||||||
dim_t k_,
|
(
|
||||||
double* restrict alpha,
|
dim_t m,
|
||||||
double* restrict a,
|
dim_t n,
|
||||||
double* restrict b,
|
dim_t k_,
|
||||||
double* restrict beta,
|
double* restrict alpha,
|
||||||
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
double* restrict a,
|
||||||
auxinfo_t* data,
|
double* restrict b,
|
||||||
cntx_t* restrict cntx
|
double* restrict beta,
|
||||||
)
|
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||||
|
auxinfo_t* data,
|
||||||
|
cntx_t* restrict cntx
|
||||||
|
)
|
||||||
{
|
{
|
||||||
(void)data;
|
(void)data;
|
||||||
(void)cntx;
|
(void)cntx;
|
||||||
|
|
||||||
const int64_t* offsetPtr = &offsets[0];
|
int64_t k = k_;
|
||||||
const int64_t k = k_;
|
int64_t rs_c = rs_c_;
|
||||||
const int64_t rs_c = rs_c_;
|
int64_t cs_c = cs_c_;
|
||||||
const int64_t cs_c = cs_c_;
|
|
||||||
|
GEMM_UKR_SETUP_CT( d, 16, 12, false );
|
||||||
|
|
||||||
BEGIN_ASM()
|
BEGIN_ASM()
|
||||||
|
|
||||||
@@ -464,62 +468,26 @@ void bli_dgemm_skx_asm_16x12_l2(
|
|||||||
|
|
||||||
MOV(RAX, VAR(cs_c))
|
MOV(RAX, VAR(cs_c))
|
||||||
LEA(RAX, MEM(,RAX,8))
|
LEA(RAX, MEM(,RAX,8))
|
||||||
MOV(RBX, VAR(rs_c))
|
|
||||||
LEA(RBX, MEM(,RBX,8))
|
|
||||||
|
|
||||||
// Check if C is column stride. If not, jump to the slow scattered update
|
VCOMISD(XMM(1), XMM(7))
|
||||||
CMP(RBX, IMM(1))
|
JE(COLSTORBZ)
|
||||||
JNE(SCATTEREDUPDATE)
|
|
||||||
|
|
||||||
VCOMISD(XMM(1), XMM(7))
|
UPDATE_C( 8, 9,10,11)
|
||||||
JE(COLSTORBZ)
|
UPDATE_C(12,13,14,15)
|
||||||
|
UPDATE_C(16,17,18,19)
|
||||||
UPDATE_C( 8, 9,10,11)
|
UPDATE_C(20,21,22,23)
|
||||||
UPDATE_C(12,13,14,15)
|
UPDATE_C(24,25,26,27)
|
||||||
UPDATE_C(16,17,18,19)
|
UPDATE_C(28,29,30,31)
|
||||||
UPDATE_C(20,21,22,23)
|
|
||||||
UPDATE_C(24,25,26,27)
|
|
||||||
UPDATE_C(28,29,30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
LABEL(COLSTORBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ( 8, 9,10,11)
|
|
||||||
UPDATE_C_BZ(12,13,14,15)
|
|
||||||
UPDATE_C_BZ(16,17,18,19)
|
|
||||||
UPDATE_C_BZ(20,21,22,23)
|
|
||||||
UPDATE_C_BZ(24,25,26,27)
|
|
||||||
UPDATE_C_BZ(28,29,30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
JMP(END)
|
||||||
LABEL(SCATTEREDUPDATE)
|
LABEL(COLSTORBZ)
|
||||||
|
|
||||||
MOV(RDI, VAR(offsetPtr))
|
UPDATE_C_BZ( 8, 9,10,11)
|
||||||
VMOVDQA64(ZMM(2), MEM(RDI,0*64))
|
UPDATE_C_BZ(12,13,14,15)
|
||||||
VMOVDQA64(ZMM(3), MEM(RDI,1*64))
|
UPDATE_C_BZ(16,17,18,19)
|
||||||
VPBROADCASTQ(ZMM(6), RBX)
|
UPDATE_C_BZ(20,21,22,23)
|
||||||
VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
|
UPDATE_C_BZ(24,25,26,27)
|
||||||
VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
|
UPDATE_C_BZ(28,29,30,31)
|
||||||
|
|
||||||
VCOMISD(XMM(1), XMM(7))
|
|
||||||
JE(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
|
|
||||||
UPDATE_C_ROW_SCATTERED(12,13,14,15)
|
|
||||||
UPDATE_C_ROW_SCATTERED(16,17,18,19)
|
|
||||||
UPDATE_C_ROW_SCATTERED(20,21,22,23)
|
|
||||||
UPDATE_C_ROW_SCATTERED(24,25,26,27)
|
|
||||||
UPDATE_C_ROW_SCATTERED(28,29,30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
LABEL(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)
|
|
||||||
|
|
||||||
LABEL(END)
|
LABEL(END)
|
||||||
|
|
||||||
@@ -535,8 +503,7 @@ void bli_dgemm_skx_asm_16x12_l2(
|
|||||||
[beta] "m" (beta),
|
[beta] "m" (beta),
|
||||||
[c] "m" (c),
|
[c] "m" (c),
|
||||||
[rs_c] "m" (rs_c),
|
[rs_c] "m" (rs_c),
|
||||||
[cs_c] "m" (cs_c),
|
[cs_c] "m" (cs_c)
|
||||||
[offsetPtr] "m" (offsetPtr)
|
|
||||||
: // register clobber list
|
: // register clobber list
|
||||||
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
||||||
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
||||||
@@ -545,4 +512,6 @@ void bli_dgemm_skx_asm_16x12_l2(
|
|||||||
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
||||||
"zmm30", "zmm31", "memory"
|
"zmm30", "zmm31", "memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,24 +153,28 @@
|
|||||||
static int64_t offsets[16] __attribute__((aligned(64))) =
|
static int64_t offsets[16] __attribute__((aligned(64))) =
|
||||||
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
||||||
|
|
||||||
void bli_dgemm_skx_asm_16x14(
|
void bli_dgemm_skx_asm_16x14
|
||||||
dim_t k_,
|
(
|
||||||
double* restrict alpha,
|
dim_t m,
|
||||||
double* restrict a,
|
dim_t n,
|
||||||
double* restrict b,
|
dim_t k_,
|
||||||
double* restrict beta,
|
double* restrict alpha,
|
||||||
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
double* restrict a,
|
||||||
auxinfo_t* data,
|
double* restrict b,
|
||||||
cntx_t* restrict cntx
|
double* restrict beta,
|
||||||
)
|
double* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||||
|
auxinfo_t* data,
|
||||||
|
cntx_t* restrict cntx
|
||||||
|
)
|
||||||
{
|
{
|
||||||
(void)data;
|
(void)data;
|
||||||
(void)cntx;
|
(void)cntx;
|
||||||
|
|
||||||
const int64_t* offsetPtr = &offsets[0];
|
int64_t k = k_;
|
||||||
const int64_t k = k_;
|
int64_t rs_c = rs_c_;
|
||||||
const int64_t rs_c = rs_c_*8;
|
int64_t cs_c = cs_c_;
|
||||||
const int64_t cs_c = cs_c_*8;
|
|
||||||
|
GEMM_UKR_SETUP_CT( d, 16, 14, false );
|
||||||
|
|
||||||
BEGIN_ASM()
|
BEGIN_ASM()
|
||||||
|
|
||||||
@@ -220,6 +224,8 @@ void bli_dgemm_skx_asm_16x14(
|
|||||||
|
|
||||||
MOV(R12, VAR(rs_c))
|
MOV(R12, VAR(rs_c))
|
||||||
MOV(R10, VAR(cs_c))
|
MOV(R10, VAR(cs_c))
|
||||||
|
LEA(R12, MEM(,R12,8))
|
||||||
|
LEA(R10, MEM(,R10,8))
|
||||||
|
|
||||||
MOV(RDI, RSI)
|
MOV(RDI, RSI)
|
||||||
AND(RSI, IMM(3))
|
AND(RSI, IMM(3))
|
||||||
@@ -320,119 +326,41 @@ void bli_dgemm_skx_asm_16x14(
|
|||||||
MOV(RAX, R12)
|
MOV(RAX, R12)
|
||||||
MOV(RBX, R10)
|
MOV(RBX, R10)
|
||||||
|
|
||||||
// Check if C is column stride.
|
VCOMISD(XMM(1), XMM(2))
|
||||||
CMP(RAX, IMM(8))
|
JE(COLSTORBZ)
|
||||||
JNE(SCATTEREDUPDATE)
|
|
||||||
|
|
||||||
VCOMISD(XMM(1), XMM(2))
|
UPDATE_C( 4, 5)
|
||||||
JE(COLSTORBZ)
|
UPDATE_C( 6, 7)
|
||||||
|
UPDATE_C( 8, 9)
|
||||||
UPDATE_C( 4, 5)
|
UPDATE_C(10,11)
|
||||||
UPDATE_C( 6, 7)
|
UPDATE_C(12,13)
|
||||||
UPDATE_C( 8, 9)
|
UPDATE_C(14,15)
|
||||||
UPDATE_C(10,11)
|
UPDATE_C(16,17)
|
||||||
UPDATE_C(12,13)
|
UPDATE_C(18,19)
|
||||||
UPDATE_C(14,15)
|
UPDATE_C(20,21)
|
||||||
UPDATE_C(16,17)
|
UPDATE_C(22,23)
|
||||||
UPDATE_C(18,19)
|
UPDATE_C(24,25)
|
||||||
UPDATE_C(20,21)
|
UPDATE_C(26,27)
|
||||||
UPDATE_C(22,23)
|
UPDATE_C(28,29)
|
||||||
UPDATE_C(24,25)
|
UPDATE_C(30,31)
|
||||||
UPDATE_C(26,27)
|
|
||||||
UPDATE_C(28,29)
|
|
||||||
UPDATE_C(30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
LABEL(COLSTORBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ( 4, 5)
|
|
||||||
UPDATE_C_BZ( 6, 7)
|
|
||||||
UPDATE_C_BZ( 8, 9)
|
|
||||||
UPDATE_C_BZ(10,11)
|
|
||||||
UPDATE_C_BZ(12,13)
|
|
||||||
UPDATE_C_BZ(14,15)
|
|
||||||
UPDATE_C_BZ(16,17)
|
|
||||||
UPDATE_C_BZ(18,19)
|
|
||||||
UPDATE_C_BZ(20,21)
|
|
||||||
UPDATE_C_BZ(22,23)
|
|
||||||
UPDATE_C_BZ(24,25)
|
|
||||||
UPDATE_C_BZ(26,27)
|
|
||||||
UPDATE_C_BZ(28,29)
|
|
||||||
UPDATE_C_BZ(30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
JMP(END)
|
||||||
LABEL(SCATTEREDUPDATE)
|
LABEL(COLSTORBZ)
|
||||||
|
|
||||||
VMULPD(ZMM( 4), ZMM( 4), ZMM(0))
|
UPDATE_C_BZ( 4, 5)
|
||||||
VMULPD(ZMM( 5), ZMM( 5), ZMM(0))
|
UPDATE_C_BZ( 6, 7)
|
||||||
VMULPD(ZMM( 6), ZMM( 6), ZMM(0))
|
UPDATE_C_BZ( 8, 9)
|
||||||
VMULPD(ZMM( 7), ZMM( 7), ZMM(0))
|
UPDATE_C_BZ(10,11)
|
||||||
VMULPD(ZMM( 8), ZMM( 8), ZMM(0))
|
UPDATE_C_BZ(12,13)
|
||||||
VMULPD(ZMM( 9), ZMM( 9), ZMM(0))
|
UPDATE_C_BZ(14,15)
|
||||||
VMULPD(ZMM(10), ZMM(10), ZMM(0))
|
UPDATE_C_BZ(16,17)
|
||||||
VMULPD(ZMM(11), ZMM(11), ZMM(0))
|
UPDATE_C_BZ(18,19)
|
||||||
VMULPD(ZMM(12), ZMM(12), ZMM(0))
|
UPDATE_C_BZ(20,21)
|
||||||
VMULPD(ZMM(13), ZMM(13), ZMM(0))
|
UPDATE_C_BZ(22,23)
|
||||||
VMULPD(ZMM(14), ZMM(14), ZMM(0))
|
UPDATE_C_BZ(24,25)
|
||||||
VMULPD(ZMM(15), ZMM(15), ZMM(0))
|
UPDATE_C_BZ(26,27)
|
||||||
VMULPD(ZMM(16), ZMM(16), ZMM(0))
|
UPDATE_C_BZ(28,29)
|
||||||
VMULPD(ZMM(17), ZMM(17), ZMM(0))
|
UPDATE_C_BZ(30,31)
|
||||||
VMULPD(ZMM(18), ZMM(18), ZMM(0))
|
|
||||||
VMULPD(ZMM(19), ZMM(19), ZMM(0))
|
|
||||||
VMULPD(ZMM(20), ZMM(20), ZMM(0))
|
|
||||||
VMULPD(ZMM(21), ZMM(21), ZMM(0))
|
|
||||||
VMULPD(ZMM(22), ZMM(22), ZMM(0))
|
|
||||||
VMULPD(ZMM(23), ZMM(23), ZMM(0))
|
|
||||||
VMULPD(ZMM(24), ZMM(24), ZMM(0))
|
|
||||||
VMULPD(ZMM(25), ZMM(25), ZMM(0))
|
|
||||||
VMULPD(ZMM(26), ZMM(26), ZMM(0))
|
|
||||||
VMULPD(ZMM(27), ZMM(27), ZMM(0))
|
|
||||||
VMULPD(ZMM(28), ZMM(28), ZMM(0))
|
|
||||||
VMULPD(ZMM(29), ZMM(29), ZMM(0))
|
|
||||||
VMULPD(ZMM(30), ZMM(30), ZMM(0))
|
|
||||||
VMULPD(ZMM(31), ZMM(31), ZMM(0))
|
|
||||||
|
|
||||||
VCOMISD(XMM(1), XMM(2))
|
|
||||||
|
|
||||||
MOV(RDI, VAR(offsetPtr))
|
|
||||||
VPBROADCASTQ(ZMM(0), RAX)
|
|
||||||
VPMULLQ(ZMM(2), ZMM(0), MEM(RDI))
|
|
||||||
VPMULLQ(ZMM(3), ZMM(0), MEM(RDI,64))
|
|
||||||
|
|
||||||
JE(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_COL_SCATTERED( 4, 5)
|
|
||||||
UPDATE_C_COL_SCATTERED( 6, 7)
|
|
||||||
UPDATE_C_COL_SCATTERED( 8, 9)
|
|
||||||
UPDATE_C_COL_SCATTERED(10,11)
|
|
||||||
UPDATE_C_COL_SCATTERED(12,13)
|
|
||||||
UPDATE_C_COL_SCATTERED(14,15)
|
|
||||||
UPDATE_C_COL_SCATTERED(16,17)
|
|
||||||
UPDATE_C_COL_SCATTERED(18,19)
|
|
||||||
UPDATE_C_COL_SCATTERED(20,21)
|
|
||||||
UPDATE_C_COL_SCATTERED(22,23)
|
|
||||||
UPDATE_C_COL_SCATTERED(24,25)
|
|
||||||
UPDATE_C_COL_SCATTERED(26,27)
|
|
||||||
UPDATE_C_COL_SCATTERED(28,29)
|
|
||||||
UPDATE_C_COL_SCATTERED(30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
LABEL(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED( 4, 5)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED( 6, 7)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED( 8, 9)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(10,11)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(12,13)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(14,15)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(16,17)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(18,19)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(20,21)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(22,23)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(24,25)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(26,27)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(28,29)
|
|
||||||
UPDATE_C_BZ_COL_SCATTERED(30,31)
|
|
||||||
|
|
||||||
LABEL(END)
|
LABEL(END)
|
||||||
|
|
||||||
@@ -449,8 +377,7 @@ void bli_dgemm_skx_asm_16x14(
|
|||||||
[beta] "m" (beta),
|
[beta] "m" (beta),
|
||||||
[c] "m" (c),
|
[c] "m" (c),
|
||||||
[rs_c] "m" (rs_c),
|
[rs_c] "m" (rs_c),
|
||||||
[cs_c] "m" (cs_c),
|
[cs_c] "m" (cs_c)
|
||||||
[offsetPtr] "m" (offsetPtr)
|
|
||||||
: // register clobber list
|
: // register clobber list
|
||||||
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
||||||
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
||||||
@@ -459,4 +386,6 @@ void bli_dgemm_skx_asm_16x14(
|
|||||||
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
||||||
"zmm30", "zmm31", "memory"
|
"zmm30", "zmm31", "memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( d );
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -317,24 +317,28 @@ ahead*/
|
|||||||
static int64_t offsets[16] __attribute__((aligned(64))) =
|
static int64_t offsets[16] __attribute__((aligned(64))) =
|
||||||
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15};
|
||||||
|
|
||||||
void bli_sgemm_skx_asm_32x12_l2(
|
void bli_sgemm_skx_asm_32x12_l2
|
||||||
dim_t k_,
|
(
|
||||||
float* restrict alpha,
|
dim_t m,
|
||||||
float* restrict a,
|
dim_t n,
|
||||||
float* restrict b,
|
dim_t k_,
|
||||||
float* restrict beta,
|
float* restrict alpha,
|
||||||
float* restrict c, inc_t rs_c_, inc_t cs_c_,
|
float* restrict a,
|
||||||
auxinfo_t* data,
|
float* restrict b,
|
||||||
cntx_t* restrict cntx
|
float* restrict beta,
|
||||||
)
|
float* restrict c, inc_t rs_c_, inc_t cs_c_,
|
||||||
|
auxinfo_t* data,
|
||||||
|
cntx_t* restrict cntx
|
||||||
|
)
|
||||||
{
|
{
|
||||||
(void)data;
|
(void)data;
|
||||||
(void)cntx;
|
(void)cntx;
|
||||||
|
|
||||||
const int64_t* offsetPtr = &offsets[0];
|
int64_t k = k_;
|
||||||
const int64_t k = k_;
|
int64_t rs_c = rs_c_;
|
||||||
const int64_t rs_c = rs_c_;
|
int64_t cs_c = cs_c_;
|
||||||
const int64_t cs_c = cs_c_;
|
|
||||||
|
GEMM_UKR_SETUP_CT( s, 32, 12, false );
|
||||||
|
|
||||||
BEGIN_ASM()
|
BEGIN_ASM()
|
||||||
|
|
||||||
@@ -381,7 +385,7 @@ void bli_sgemm_skx_asm_32x12_l2(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef PREFETCH_B_BEFORE
|
#ifdef PREFETCH_B_BEFORE
|
||||||
/* Prefetching 3 cachlines of B (4 iterations worth of data
|
/* Prefetching 3 cachlines of B (4 iterations worth of data
|
||||||
(12 (NR) x 4 (sizeof(float)) x 4 iter /64 = 3 cachelines) */
|
(12 (NR) x 4 (sizeof(float)) x 4 iter /64 = 3 cachelines) */
|
||||||
PREFETCH(0, MEM(RBX,0*64))
|
PREFETCH(0, MEM(RBX,0*64))
|
||||||
PREFETCH(0, MEM(RBX,1*64))
|
PREFETCH(0, MEM(RBX,1*64))
|
||||||
@@ -485,66 +489,26 @@ void bli_sgemm_skx_asm_32x12_l2(
|
|||||||
|
|
||||||
MOV(RAX, VAR(cs_c))
|
MOV(RAX, VAR(cs_c))
|
||||||
LEA(RAX, MEM(,RAX,4))
|
LEA(RAX, MEM(,RAX,4))
|
||||||
MOV(RBX, VAR(rs_c))
|
|
||||||
LEA(RBX, MEM(,RBX,4))
|
|
||||||
|
|
||||||
|
VCOMISS(XMM(1), XMM(7))
|
||||||
|
JE(COLSTORBZ)
|
||||||
|
|
||||||
// Check if C is column major (rs_c = 1). If not, jump to the slow scattered update
|
UPDATE_C( 8, 9,10,11)
|
||||||
CMP(RBX, IMM(4))
|
UPDATE_C(12,13,14,15)
|
||||||
JNE(SCATTEREDUPDATE)
|
UPDATE_C(16,17,18,19)
|
||||||
|
UPDATE_C(20,21,22,23)
|
||||||
VCOMISS(XMM(1), XMM(7))
|
UPDATE_C(24,25,26,27)
|
||||||
JE(COLSTORBZ)
|
UPDATE_C(28,29,30,31)
|
||||||
|
|
||||||
UPDATE_C( 8, 9,10,11)
|
|
||||||
UPDATE_C(12,13,14,15)
|
|
||||||
UPDATE_C(16,17,18,19)
|
|
||||||
UPDATE_C(20,21,22,23)
|
|
||||||
UPDATE_C(24,25,26,27)
|
|
||||||
UPDATE_C(28,29,30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
LABEL(COLSTORBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ( 8, 9,10,11)
|
|
||||||
UPDATE_C_BZ(12,13,14,15)
|
|
||||||
UPDATE_C_BZ(16,17,18,19)
|
|
||||||
UPDATE_C_BZ(20,21,22,23)
|
|
||||||
UPDATE_C_BZ(24,25,26,27)
|
|
||||||
UPDATE_C_BZ(28,29,30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
JMP(END)
|
||||||
LABEL(SCATTEREDUPDATE)
|
LABEL(COLSTORBZ)
|
||||||
|
|
||||||
LEA(RDX, MEM(RCX,RBX,8))
|
UPDATE_C_BZ( 8, 9,10,11)
|
||||||
LEA(RDX, MEM(RDX,RBX,8))
|
UPDATE_C_BZ(12,13,14,15)
|
||||||
|
UPDATE_C_BZ(16,17,18,19)
|
||||||
MOV(RDI, VAR(offsetPtr))
|
UPDATE_C_BZ(20,21,22,23)
|
||||||
VMOVDQA64(ZMM(2), MEM(RDI,0*64))
|
UPDATE_C_BZ(24,25,26,27)
|
||||||
VMOVDQA64(ZMM(3), MEM(RDI,1*64))
|
UPDATE_C_BZ(28,29,30,31)
|
||||||
VPBROADCASTQ(ZMM(6), RBX)
|
|
||||||
VPMULLQ(ZMM(2), ZMM(6), ZMM(2))
|
|
||||||
VPMULLQ(ZMM(3), ZMM(6), ZMM(3))
|
|
||||||
|
|
||||||
VCOMISS(XMM(1), XMM(7))
|
|
||||||
JE(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_ROW_SCATTERED( 8, 9,10,11)
|
|
||||||
UPDATE_C_ROW_SCATTERED(12,13,14,15)
|
|
||||||
UPDATE_C_ROW_SCATTERED(16,17,18,19)
|
|
||||||
UPDATE_C_ROW_SCATTERED(20,21,22,23)
|
|
||||||
UPDATE_C_ROW_SCATTERED(24,25,26,27)
|
|
||||||
UPDATE_C_ROW_SCATTERED(28,29,30,31)
|
|
||||||
|
|
||||||
JMP(END)
|
|
||||||
LABEL(SCATTERBZ)
|
|
||||||
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27)
|
|
||||||
UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31)
|
|
||||||
|
|
||||||
LABEL(END)
|
LABEL(END)
|
||||||
|
|
||||||
@@ -560,8 +524,7 @@ void bli_sgemm_skx_asm_32x12_l2(
|
|||||||
[beta] "m" (beta),
|
[beta] "m" (beta),
|
||||||
[c] "m" (c),
|
[c] "m" (c),
|
||||||
[rs_c] "m" (rs_c),
|
[rs_c] "m" (rs_c),
|
||||||
[cs_c] "m" (cs_c),
|
[cs_c] "m" (cs_c)
|
||||||
[offsetPtr] "m" (offsetPtr)
|
|
||||||
: // register clobber list
|
: // register clobber list
|
||||||
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
"rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12",
|
||||||
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
"r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5",
|
||||||
@@ -570,4 +533,6 @@ void bli_sgemm_skx_asm_32x12_l2(
|
|||||||
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
"zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29",
|
||||||
"zmm30", "zmm31", "memory"
|
"zmm30", "zmm31", "memory"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
GEMM_UKR_FLUSH_CT( s );
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,8 @@
|
|||||||
\
|
\
|
||||||
void PASTEMAC3(ch,opname,arch,suf) \
|
void PASTEMAC3(ch,opname,arch,suf) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
@@ -59,9 +61,6 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
\
|
\
|
||||||
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||||
\
|
|
||||||
const dim_t m = mr; \
|
|
||||||
const dim_t n = nr; \
|
|
||||||
\
|
\
|
||||||
const inc_t cs_a = packmr; \
|
const inc_t cs_a = packmr; \
|
||||||
\
|
\
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11", mr, 2*nr, \
|
|||||||
/* upper: b11 = alpha * b11 - a12 * b21; */ \
|
/* upper: b11 = alpha * b11 - a12 * b21; */ \
|
||||||
gemm_ukr \
|
gemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
mr, \
|
||||||
|
nr, \
|
||||||
k, \
|
k, \
|
||||||
minus_one, \
|
minus_one, \
|
||||||
a1x, \
|
a1x, \
|
||||||
|
|||||||
@@ -44,6 +44,8 @@
|
|||||||
\
|
\
|
||||||
void PASTEMAC3(ch,opname,arch,suf) \
|
void PASTEMAC3(ch,opname,arch,suf) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
@@ -107,8 +109,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
\
|
\
|
||||||
if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||||
{ \
|
{ \
|
||||||
for ( dim_t i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
for ( dim_t j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
PASTEMAC(ch,copys) \
|
PASTEMAC(ch,copys) \
|
||||||
( \
|
( \
|
||||||
ab[ i*rs_ab + j*cs_ab ], \
|
ab[ i*rs_ab + j*cs_ab ], \
|
||||||
@@ -117,8 +119,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
} \
|
} \
|
||||||
else \
|
else \
|
||||||
{ \
|
{ \
|
||||||
for ( dim_t i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
for ( dim_t j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
PASTEMAC(ch,xpbys) \
|
PASTEMAC(ch,xpbys) \
|
||||||
( \
|
( \
|
||||||
ab[ i*rs_ab + j*cs_ab ], \
|
ab[ i*rs_ab + j*cs_ab ], \
|
||||||
@@ -133,8 +135,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
\
|
\
|
||||||
if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||||
{ \
|
{ \
|
||||||
for ( dim_t j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
for ( dim_t i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
PASTEMAC(ch,copys) \
|
PASTEMAC(ch,copys) \
|
||||||
( \
|
( \
|
||||||
ab[ i*rs_ab + j*cs_ab ], \
|
ab[ i*rs_ab + j*cs_ab ], \
|
||||||
@@ -143,8 +145,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
} \
|
} \
|
||||||
else \
|
else \
|
||||||
{ \
|
{ \
|
||||||
for ( dim_t j = 0; j < nr; ++j ) \
|
for ( dim_t j = 0; j < n; ++j ) \
|
||||||
for ( dim_t i = 0; i < mr; ++i ) \
|
for ( dim_t i = 0; i < m; ++i ) \
|
||||||
PASTEMAC(ch,xpbys) \
|
PASTEMAC(ch,xpbys) \
|
||||||
( \
|
( \
|
||||||
ab[ i*rs_ab + j*cs_ab ], \
|
ab[ i*rs_ab + j*cs_ab ], \
|
||||||
@@ -171,6 +173,8 @@ GENTFUNC( dcomplex, z, gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 )
|
|||||||
\
|
\
|
||||||
void PASTEMAC3(ch,opname,arch,suf) \
|
void PASTEMAC3(ch,opname,arch,suf) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
@@ -188,9 +192,6 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
\
|
\
|
||||||
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \
|
||||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||||
\
|
|
||||||
const dim_t m = mr; \
|
|
||||||
const dim_t n = nr; \
|
|
||||||
\
|
\
|
||||||
const inc_t cs_a = packmr; \
|
const inc_t cs_a = packmr; \
|
||||||
\
|
\
|
||||||
|
|||||||
@@ -52,6 +52,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
{ \
|
{ \
|
||||||
const num_t dt = PASTEMAC(ch,type); \
|
const num_t dt = PASTEMAC(ch,type); \
|
||||||
\
|
\
|
||||||
|
const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||||
|
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||||
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \
|
||||||
\
|
\
|
||||||
const inc_t rs_b = packnr; \
|
const inc_t rs_b = packnr; \
|
||||||
@@ -68,6 +70,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
/* upper: b11 = alpha * b11 - a12 * b21; */ \
|
/* upper: b11 = alpha * b11 - a12 * b21; */ \
|
||||||
gemm_ukr \
|
gemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
mr, \
|
||||||
|
nr, \
|
||||||
k, \
|
k, \
|
||||||
minus_one, \
|
minus_one, \
|
||||||
a1x, \
|
a1x, \
|
||||||
|
|||||||
@@ -39,6 +39,8 @@
|
|||||||
\
|
\
|
||||||
void PASTEMAC3(ch,opname,arch,suf) \
|
void PASTEMAC3(ch,opname,arch,suf) \
|
||||||
( \
|
( \
|
||||||
|
dim_t m, \
|
||||||
|
dim_t n, \
|
||||||
dim_t k, \
|
dim_t k, \
|
||||||
ctype* restrict alpha, \
|
ctype* restrict alpha, \
|
||||||
ctype* restrict a, \
|
ctype* restrict a, \
|
||||||
@@ -59,6 +61,9 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
\
|
\
|
||||||
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \
|
||||||
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \
|
||||||
|
\
|
||||||
|
const dim_t mr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \
|
||||||
|
const dim_t nr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \
|
||||||
\
|
\
|
||||||
const dim_t k2 = 2 * k; \
|
const dim_t k2 = 2 * k; \
|
||||||
\
|
\
|
||||||
@@ -118,6 +123,11 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \
|
else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \
|
||||||
else using_ct = FALSE; \
|
else using_ct = FALSE; \
|
||||||
\
|
\
|
||||||
|
\
|
||||||
|
/* If we are not computing a full micro-tile, then we must write to
|
||||||
|
ct and then accumulate to c afterwards. */ \
|
||||||
|
if ( mr != m || nr != n ) using_ct = TRUE; \
|
||||||
|
\
|
||||||
\
|
\
|
||||||
if ( using_ct ) \
|
if ( using_ct ) \
|
||||||
{ \
|
{ \
|
||||||
@@ -149,6 +159,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
/* c = beta * c + alpha_r * a * b; */ \
|
/* c = beta * c + alpha_r * a * b; */ \
|
||||||
rgemm_ukr \
|
rgemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
mr_r, \
|
||||||
|
nr_r, \
|
||||||
k2, \
|
k2, \
|
||||||
alpha_r, \
|
alpha_r, \
|
||||||
a_r, \
|
a_r, \
|
||||||
@@ -164,8 +176,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
/* Accumulate the final result in ct back to c. */ \
|
/* Accumulate the final result in ct back to c. */ \
|
||||||
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
if ( PASTEMAC(ch,eq1)( *beta ) ) \
|
||||||
{ \
|
{ \
|
||||||
for ( j = 0; j < nr; ++j ) \
|
for ( j = 0; j < n; ++j ) \
|
||||||
for ( i = 0; i < mr; ++i ) \
|
for ( i = 0; i < m; ++i ) \
|
||||||
{ \
|
{ \
|
||||||
PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
|
PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
|
||||||
*(c + i*rs_c + j*cs_c ) ); \
|
*(c + i*rs_c + j*cs_c ) ); \
|
||||||
@@ -173,8 +185,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
} \
|
} \
|
||||||
else if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
else if ( PASTEMAC(ch,eq0)( *beta ) ) \
|
||||||
{ \
|
{ \
|
||||||
for ( j = 0; j < nr; ++j ) \
|
for ( j = 0; j < n; ++j ) \
|
||||||
for ( i = 0; i < mr; ++i ) \
|
for ( i = 0; i < m; ++i ) \
|
||||||
{ \
|
{ \
|
||||||
PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
|
PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||||
*(c + i*rs_c + j*cs_c ) ); \
|
*(c + i*rs_c + j*cs_c ) ); \
|
||||||
@@ -182,8 +194,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
} \
|
} \
|
||||||
else \
|
else \
|
||||||
{ \
|
{ \
|
||||||
for ( j = 0; j < nr; ++j ) \
|
for ( j = 0; j < n; ++j ) \
|
||||||
for ( i = 0; i < mr; ++i ) \
|
for ( i = 0; i < m; ++i ) \
|
||||||
{ \
|
{ \
|
||||||
PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
|
PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
|
||||||
*beta, \
|
*beta, \
|
||||||
@@ -215,6 +227,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
/* c = beta * c + alpha_r * a * b; */ \
|
/* c = beta * c + alpha_r * a * b; */ \
|
||||||
rgemm_ukr \
|
rgemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
mr_r, \
|
||||||
|
nr_r, \
|
||||||
k2, \
|
k2, \
|
||||||
alpha_r, \
|
alpha_r, \
|
||||||
a_r, \
|
a_r, \
|
||||||
|
|||||||
@@ -153,6 +153,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
|
|||||||
upper: bt = -1.0 * a12 * b21; */ \
|
upper: bt = -1.0 * a12 * b21; */ \
|
||||||
rgemm_ukr \
|
rgemm_ukr \
|
||||||
( \
|
( \
|
||||||
|
mr_r, \
|
||||||
|
nr_r, \
|
||||||
k2, \
|
k2, \
|
||||||
minus_one_r, \
|
minus_one_r, \
|
||||||
a1x_r, \
|
a1x_r, \
|
||||||
|
|||||||
267
test/syrk_diagonal/complex_math.hpp
Normal file
267
test/syrk_diagonal/complex_math.hpp
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
#include <cmath>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "blis.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_complex : std::false_type {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct is_complex<scomplex> : std::true_type {};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct is_complex<dcomplex> : std::true_type {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};
|
||||||
|
|
||||||
|
template <typename T> struct make_complex;
|
||||||
|
|
||||||
|
template <> struct make_complex<float > { using type = scomplex; };
|
||||||
|
template <> struct make_complex<double > { using type = dcomplex; };
|
||||||
|
template <> struct make_complex<scomplex> { using type = scomplex; };
|
||||||
|
template <> struct make_complex<dcomplex> { using type = dcomplex; };
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using make_complex_t = typename make_complex<T>::type;
|
||||||
|
|
||||||
|
template <typename T> struct make_real;
|
||||||
|
|
||||||
|
template <> struct make_real<float > { using type = float; };
|
||||||
|
template <> struct make_real<double > { using type = double; };
|
||||||
|
template <> struct make_real<scomplex> { using type = float; };
|
||||||
|
template <> struct make_real<dcomplex> { using type = double; };
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using make_real_t = typename make_real<T>::type;
|
||||||
|
|
||||||
|
template <typename T, bool Cond>
|
||||||
|
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};
|
||||||
|
|
||||||
|
template <typename T, bool Cond>
|
||||||
|
using make_complex_if_t = typename make_complex_if<T,Cond>::type;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct real_imag_part
|
||||||
|
{
|
||||||
|
real_imag_part& operator=(T) { return *this; }
|
||||||
|
|
||||||
|
operator T() const { return T(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }
|
||||||
|
|
||||||
|
inline float& real(scomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline float& imag(scomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
inline double& real(dcomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline double& imag(dcomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
inline const float& real(const scomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline const float& imag(const scomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
inline const double& real(const dcomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline const double& imag(const dcomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
|
||||||
|
|
||||||
|
template <typename T, typename U, typename=void>
|
||||||
|
struct convert_impl;
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y = x; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y = x.real; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U, typename T>
|
||||||
|
U convert(T x)
|
||||||
|
{
|
||||||
|
U y;
|
||||||
|
convert_impl<T,U>{}(x,y);
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, typename T>
|
||||||
|
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
|
||||||
|
{
|
||||||
|
return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define COMPLEX_MATH_OPS(rtype, ctype) \
|
||||||
|
\
|
||||||
|
inline bool operator==(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return x == y.real && y.imag == 0; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline bool operator==(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return y == x.real && x.imag == 0; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline bool operator==(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return x.real == y.real && \
|
||||||
|
x.imag == y.imag; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(ctype x) \
|
||||||
|
{ \
|
||||||
|
return {-x.real, -x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator+(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x+y.real, y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator+(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {y+x.real, x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator+(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real+y.real, x.imag+y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x-y.real, -y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real-y, x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real-y.real, x.imag-y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator*(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x*y.real, x*y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator*(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {y*x.real, y*x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator*(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real*y.real - x.imag*y.imag, \
|
||||||
|
x.real*y.imag + x.imag*y.real}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator/(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
|
||||||
|
auto n = std::ilogb(scale); \
|
||||||
|
auto yrs = std::scalbn(y.real, -n); \
|
||||||
|
auto yis = std::scalbn(y.imag, -n); \
|
||||||
|
auto denom = y.real*yrs + y.imag*yis; \
|
||||||
|
return {x*yrs/denom, -x*yis/denom}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator/(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real/y, x.imag/y}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator/(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
|
||||||
|
auto n = std::ilogb(scale); \
|
||||||
|
auto yrs = std::scalbn(y.real, -n); \
|
||||||
|
auto yis = std::scalbn(y.imag, -n); \
|
||||||
|
auto denom = y.real*yrs + y.imag*yis; \
|
||||||
|
return {(x.real*yrs + x.imag*yis)/denom, \
|
||||||
|
(x.imag*yrs - x.real*yis)/denom}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator+=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real += y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator+=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x.real += y.real; x.imag += y.imag; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator-=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real -= y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator-=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x.real -= y.real; x.imag -= y.imag; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator*=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real *= y; x.imag *= y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator*=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x = x * y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator/=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real /= y; x.imag /= y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator/=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x = x / y; \
|
||||||
|
return x; \
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPLEX_MATH_OPS(float, scomplex);
|
||||||
|
COMPLEX_MATH_OPS(double, dcomplex);
|
||||||
|
|
||||||
186
test/syrk_diagonal/syrk_diagonal_example.c
Normal file
186
test/syrk_diagonal/syrk_diagonal_example.c
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
#include "syrk_diagonal_ref.h"
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Structure which includes all additional information beyond what is
|
||||||
|
* already stored in the obj_t structure.
|
||||||
|
*
|
||||||
|
* This structure is **read-only** during the operation!
|
||||||
|
*/
|
||||||
|
typedef struct packm_diag_params_t
|
||||||
|
{
|
||||||
|
packm_blk_var1_params_t super;
|
||||||
|
void* d;
|
||||||
|
inc_t incd;
|
||||||
|
} packm_diag_params_t;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Declare the pack kernel type and set up and array of
|
||||||
|
* packing kernels, one for each data type.
|
||||||
|
*/
|
||||||
|
#undef GENTFUNC
|
||||||
|
#define GENTFUNC(ctype,ch,op) \
|
||||||
|
void PASTEMAC(ch,op) \
|
||||||
|
( \
|
||||||
|
struc_t struca, \
|
||||||
|
diag_t diaga, \
|
||||||
|
uplo_t uploa, \
|
||||||
|
conj_t conja, \
|
||||||
|
pack_t schema, \
|
||||||
|
bool invdiag, \
|
||||||
|
dim_t panel_dim, \
|
||||||
|
dim_t panel_len, \
|
||||||
|
dim_t panel_dim_max, \
|
||||||
|
dim_t panel_len_max, \
|
||||||
|
dim_t panel_dim_off, \
|
||||||
|
dim_t panel_len_off, \
|
||||||
|
void* restrict kappa, \
|
||||||
|
void* restrict a, inc_t inca, inc_t lda, \
|
||||||
|
void* restrict p, inc_t ldp, \
|
||||||
|
inc_t is_p, \
|
||||||
|
cntx_t* cntx, \
|
||||||
|
void* params \
|
||||||
|
) \
|
||||||
|
{ \
|
||||||
|
packm_diag_params_t* params_cast = params; \
|
||||||
|
ctype* restrict a_cast = a; \
|
||||||
|
ctype* restrict p_cast = p; \
|
||||||
|
ctype* restrict d_cast = params_cast->d; \
|
||||||
|
inc_t incd = params_cast->incd; \
|
||||||
|
ctype kappa_cast = *( ctype* )kappa; \
|
||||||
|
\
|
||||||
|
if ( schema != BLIS_PACKED_ROW_PANELS && \
|
||||||
|
schema != BLIS_PACKED_COL_PANELS ) \
|
||||||
|
bli_abort(); \
|
||||||
|
\
|
||||||
|
/* Apply the offset */ \
|
||||||
|
d_cast += panel_len_off * incd; \
|
||||||
|
\
|
||||||
|
if ( conja ) \
|
||||||
|
{ \
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||||
|
{ \
|
||||||
|
ctype kappa_d; \
|
||||||
|
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++) \
|
||||||
|
PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||||
|
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
else \
|
||||||
|
{ \
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||||
|
{ \
|
||||||
|
ctype kappa_d; \
|
||||||
|
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++) \
|
||||||
|
PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||||
|
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
for (dim_t j = panel_len;j < panel_len_max;j++) \
|
||||||
|
for (dim_t i = 0;i < panel_dim_max;i++) \
|
||||||
|
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||||
|
}
|
||||||
|
|
||||||
|
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
|
||||||
|
|
||||||
|
static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Modify the object A to include information about the diagonal D,
|
||||||
|
* and imbue it with special function pointers which will take care
|
||||||
|
* of the actual work of forming (D * A^T)
|
||||||
|
*/
|
||||||
|
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||||
|
{
|
||||||
|
memset( params, 0, sizeof(*params) );
|
||||||
|
|
||||||
|
// Assumes D is a column vector
|
||||||
|
params->d = bli_obj_buffer_at_off( d );
|
||||||
|
params->incd = bli_obj_row_stride( d );
|
||||||
|
|
||||||
|
for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
|
||||||
|
params->super.ukr_fn[i][i] = packm_diag_ukrs[i];
|
||||||
|
|
||||||
|
// Attach the parameters to the A object.
|
||||||
|
bli_obj_set_pack_params( params, a );
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implements C := alpha * A * D * A^T + beta * C
|
||||||
|
*
|
||||||
|
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||||
|
*/
|
||||||
|
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||||
|
{
|
||||||
|
obj_t ad; // this is (D * A^T)
|
||||||
|
packm_diag_params_t params;
|
||||||
|
|
||||||
|
bli_obj_alias_to( a, &ad );
|
||||||
|
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||||
|
attach_diagonal_factor( ¶ms, d, &ad );
|
||||||
|
|
||||||
|
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||||
|
bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
|
||||||
|
}
|
||||||
|
|
||||||
|
int main( void )
|
||||||
|
{
|
||||||
|
obj_t a;
|
||||||
|
obj_t d;
|
||||||
|
obj_t c;
|
||||||
|
obj_t c_copy;
|
||||||
|
obj_t norm;
|
||||||
|
|
||||||
|
dim_t m = 10;
|
||||||
|
dim_t k = 10;
|
||||||
|
|
||||||
|
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||||
|
for ( int upper = 0; upper <= 1; upper++ )
|
||||||
|
for ( int transa = 0; transa <= 1; transa++ )
|
||||||
|
for ( int transc = 0; transc <= 1; transc++ )
|
||||||
|
{
|
||||||
|
num_t dt = dt_;
|
||||||
|
uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||||
|
|
||||||
|
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||||
|
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||||
|
bli_obj_set_uplo( uplo , &c );
|
||||||
|
bli_obj_set_uplo( uplo , &c_copy );
|
||||||
|
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||||
|
|
||||||
|
bli_randm( &a );
|
||||||
|
bli_randm( &d );
|
||||||
|
bli_randm( &c );
|
||||||
|
bli_copym( &c, &c_copy );
|
||||||
|
|
||||||
|
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||||
|
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||||
|
|
||||||
|
bli_subm( &c_copy, &c );
|
||||||
|
bli_normfm( &c, &norm );
|
||||||
|
|
||||||
|
double normr, normi;
|
||||||
|
bli_getsc( &norm, &normr, &normi );
|
||||||
|
|
||||||
|
printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||||
|
dt, upper, transa, transc, normr );
|
||||||
|
|
||||||
|
bli_obj_free( &a );
|
||||||
|
bli_obj_free( &d );
|
||||||
|
bli_obj_free( &c );
|
||||||
|
bli_obj_free( &c_copy );
|
||||||
|
bli_obj_free( &norm );
|
||||||
|
}
|
||||||
|
}
|
||||||
220
test/syrk_diagonal/syrk_diagonal_example.cxx
Normal file
220
test/syrk_diagonal/syrk_diagonal_example.cxx
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
#include "syrk_diagonal_ref.h"
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Forward-declare the pack kernel type and set up and array of
|
||||||
|
* packing kernels, one for each data type.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void packm_diag_ukr
|
||||||
|
(
|
||||||
|
struc_t /*struca*/,
|
||||||
|
diag_t /*diaga*/,
|
||||||
|
uplo_t /*uploa*/,
|
||||||
|
conj_t conja,
|
||||||
|
pack_t schema,
|
||||||
|
bool /*invdiag*/,
|
||||||
|
dim_t panel_dim,
|
||||||
|
dim_t panel_len,
|
||||||
|
dim_t panel_dim_max,
|
||||||
|
dim_t panel_len_max,
|
||||||
|
dim_t /*panel_dim_off*/,
|
||||||
|
dim_t panel_len_off,
|
||||||
|
void* restrict kappa,
|
||||||
|
void* restrict a, inc_t inca, inc_t lda,
|
||||||
|
void* restrict p, inc_t ldp,
|
||||||
|
inc_t /*is_p*/,
|
||||||
|
cntx_t* /*cntx*/,
|
||||||
|
void* params
|
||||||
|
);
|
||||||
|
|
||||||
|
#undef GENTFUNC
|
||||||
|
#define GENTFUNC(ctype,ch,op) \
|
||||||
|
static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
|
||||||
|
|
||||||
|
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
|
||||||
|
|
||||||
|
static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Structure which includes all additional information beyond what is
|
||||||
|
* already stored in the obj_t structure.
|
||||||
|
*
|
||||||
|
* This structure is **read-only** during the operation!
|
||||||
|
*/
|
||||||
|
struct packm_diag_params_t : packm_blk_var1_params_t
|
||||||
|
{
|
||||||
|
void* d;
|
||||||
|
inc_t incd;
|
||||||
|
|
||||||
|
packm_diag_params_t() {}
|
||||||
|
|
||||||
|
packm_diag_params_t( void* d, inc_t incd )
|
||||||
|
: d(d), incd(incd)
|
||||||
|
{
|
||||||
|
for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
|
||||||
|
ukr_fn[i][i] = packm_diag_ukrs[i];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Selecting a different kernel based on the current architecture is
|
||||||
|
* currently not possible, but is something we plan to support.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void packm_diag_ukr
|
||||||
|
(
|
||||||
|
struc_t /*struca*/,
|
||||||
|
diag_t /*diaga*/,
|
||||||
|
uplo_t /*uploa*/,
|
||||||
|
conj_t conja,
|
||||||
|
pack_t schema,
|
||||||
|
bool /*invdiag*/,
|
||||||
|
dim_t panel_dim,
|
||||||
|
dim_t panel_len,
|
||||||
|
dim_t panel_dim_max,
|
||||||
|
dim_t panel_len_max,
|
||||||
|
dim_t /*panel_dim_off*/,
|
||||||
|
dim_t panel_len_off,
|
||||||
|
void* restrict kappa,
|
||||||
|
void* restrict a, inc_t inca, inc_t lda,
|
||||||
|
void* restrict p, inc_t ldp,
|
||||||
|
inc_t /*is_p*/,
|
||||||
|
cntx_t* /*cntx*/,
|
||||||
|
void* params
|
||||||
|
)
|
||||||
|
{
|
||||||
|
auto params_cast = ( packm_diag_params_t* )params;
|
||||||
|
T* restrict a_cast = ( T* )a;
|
||||||
|
T* restrict p_cast = ( T* )p;
|
||||||
|
T* restrict d_cast = ( T* )params_cast->d;
|
||||||
|
auto incd = params_cast->incd;
|
||||||
|
auto kappa_cast = *( T* )kappa;
|
||||||
|
|
||||||
|
if ( schema != BLIS_PACKED_ROW_PANELS &&
|
||||||
|
schema != BLIS_PACKED_COL_PANELS )
|
||||||
|
bli_abort();
|
||||||
|
|
||||||
|
/* Apply the offset */
|
||||||
|
d_cast += panel_len_off * incd;
|
||||||
|
|
||||||
|
if ( conja )
|
||||||
|
{
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ )
|
||||||
|
{
|
||||||
|
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||||
|
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++)
|
||||||
|
p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
|
||||||
|
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ )
|
||||||
|
{
|
||||||
|
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||||
|
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++)
|
||||||
|
p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
|
||||||
|
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (dim_t j = panel_len;j < panel_len_max;j++)
|
||||||
|
for (dim_t i = 0;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Modify the object A to include information about the diagonal D,
|
||||||
|
* and imbue it with special function pointers which will take care
|
||||||
|
* of the actual work of forming (D * A^T)
|
||||||
|
*/
|
||||||
|
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||||
|
{
|
||||||
|
// Assumes D is a column vector
|
||||||
|
new (params) packm_diag_params_t
|
||||||
|
(
|
||||||
|
bli_obj_buffer_at_off( d ),
|
||||||
|
bli_obj_row_stride( d )
|
||||||
|
);
|
||||||
|
|
||||||
|
// Attach the parameters to the A object.
|
||||||
|
bli_obj_set_pack_params( params, a );
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implements C := alpha * A * D * A^T + beta * C
|
||||||
|
*
|
||||||
|
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||||
|
*/
|
||||||
|
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||||
|
{
|
||||||
|
obj_t ad; // this is (D * A^T)
|
||||||
|
packm_diag_params_t params;
|
||||||
|
|
||||||
|
bli_obj_alias_to( a, &ad );
|
||||||
|
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||||
|
attach_diagonal_factor( ¶ms, d, &ad );
|
||||||
|
|
||||||
|
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||||
|
bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
|
||||||
|
}
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
obj_t a;
|
||||||
|
obj_t d;
|
||||||
|
obj_t c;
|
||||||
|
obj_t c_copy;
|
||||||
|
obj_t norm;
|
||||||
|
|
||||||
|
auto m = 10;
|
||||||
|
auto k = 10;
|
||||||
|
|
||||||
|
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||||
|
for ( int upper = 0; upper <= 1; upper++ )
|
||||||
|
for ( int transa = 0; transa <= 1; transa++ )
|
||||||
|
for ( int transc = 0; transc <= 1; transc++ )
|
||||||
|
{
|
||||||
|
auto dt = ( num_t )dt_;
|
||||||
|
auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||||
|
|
||||||
|
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||||
|
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||||
|
bli_obj_set_uplo( uplo , &c );
|
||||||
|
bli_obj_set_uplo( uplo , &c_copy );
|
||||||
|
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||||
|
|
||||||
|
bli_randm( &a );
|
||||||
|
bli_randm( &d );
|
||||||
|
bli_randm( &c );
|
||||||
|
bli_copym( &c, &c_copy );
|
||||||
|
|
||||||
|
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||||
|
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||||
|
|
||||||
|
bli_subm( &c_copy, &c );
|
||||||
|
bli_normfm( &c, &norm );
|
||||||
|
|
||||||
|
double normr, normi;
|
||||||
|
bli_getsc( &norm, &normr, &normi );
|
||||||
|
|
||||||
|
printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||||
|
dt, upper, transa, transc, normr);
|
||||||
|
|
||||||
|
bli_obj_free( &a );
|
||||||
|
bli_obj_free( &d );
|
||||||
|
bli_obj_free( &c );
|
||||||
|
bli_obj_free( &c_copy );
|
||||||
|
bli_obj_free( &norm );
|
||||||
|
}
|
||||||
|
}
|
||||||
354
test/syrk_diagonal/syrk_diagonal_example2.c
Normal file
354
test/syrk_diagonal/syrk_diagonal_example2.c
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
#include "syrk_diagonal_ref.h"
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Structure which includes all additional information beyond what is
|
||||||
|
* already stored in the obj_t structure.
|
||||||
|
*
|
||||||
|
* This structure is **read-only** during the operation!
|
||||||
|
*/
|
||||||
|
typedef struct packm_diag_params_t
|
||||||
|
{
|
||||||
|
void* d;
|
||||||
|
inc_t incd;
|
||||||
|
} packm_diag_params_t;
|
||||||
|
|
||||||
|
typedef void (*packm_diag_ukr_vft)
|
||||||
|
(
|
||||||
|
bool conja,
|
||||||
|
dim_t panel_dim,
|
||||||
|
dim_t panel_len,
|
||||||
|
dim_t panel_dim_max,
|
||||||
|
dim_t panel_len_max,
|
||||||
|
void* restrict kappa,
|
||||||
|
void* restrict d, inc_t incd,
|
||||||
|
void* restrict a, inc_t inca, inc_t lda,
|
||||||
|
void* restrict p, inc_t ldp
|
||||||
|
);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Declare the pack kernel type and set up and array of
|
||||||
|
* packing kernels, one for each data type.
|
||||||
|
*/
|
||||||
|
#undef GENTFUNC
|
||||||
|
#define GENTFUNC(ctype,ch,op) \
|
||||||
|
void PASTEMAC(ch,op) \
|
||||||
|
( \
|
||||||
|
bool conja, \
|
||||||
|
dim_t panel_dim, \
|
||||||
|
dim_t panel_len, \
|
||||||
|
dim_t panel_dim_max, \
|
||||||
|
dim_t panel_len_max, \
|
||||||
|
void* restrict kappa, \
|
||||||
|
void* restrict d, inc_t incd, \
|
||||||
|
void* restrict a, inc_t inca, inc_t lda, \
|
||||||
|
void* restrict p, inc_t ldp \
|
||||||
|
) \
|
||||||
|
{ \
|
||||||
|
ctype* restrict a_cast = a; \
|
||||||
|
ctype* restrict p_cast = p; \
|
||||||
|
ctype* restrict d_cast = d; \
|
||||||
|
ctype kappa_cast = *( ctype* )kappa; \
|
||||||
|
\
|
||||||
|
if ( conja ) \
|
||||||
|
{ \
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||||
|
{ \
|
||||||
|
ctype kappa_d; \
|
||||||
|
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++) \
|
||||||
|
PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||||
|
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
else \
|
||||||
|
{ \
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ ) \
|
||||||
|
{ \
|
||||||
|
ctype kappa_d; \
|
||||||
|
PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++) \
|
||||||
|
PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
|
||||||
|
\
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++) \
|
||||||
|
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
for (dim_t j = panel_len;j < panel_len_max;j++) \
|
||||||
|
for (dim_t i = 0;i < panel_dim_max;i++) \
|
||||||
|
PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
|
||||||
|
}
|
||||||
|
|
||||||
|
INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
|
||||||
|
|
||||||
|
static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||||
|
|
||||||
|
void packm_diag
|
||||||
|
(
|
||||||
|
obj_t* a,
|
||||||
|
obj_t* p,
|
||||||
|
cntx_t* cntx,
|
||||||
|
rntm_t* rntm,
|
||||||
|
cntl_t* cntl,
|
||||||
|
thrinfo_t* thread
|
||||||
|
)
|
||||||
|
{
|
||||||
|
#if 1
|
||||||
|
|
||||||
|
// We begin by copying the fields of A.
|
||||||
|
bli_obj_alias_to( a, p );
|
||||||
|
|
||||||
|
// Get information about data types.
|
||||||
|
num_t dt = bli_obj_dt( a );
|
||||||
|
num_t dt_tar = bli_obj_target_dt( a );
|
||||||
|
num_t dt_scalar = bli_obj_scalar_dt( a );
|
||||||
|
dim_t dt_size = bli_dt_size( dt );
|
||||||
|
|
||||||
|
if ( dt_scalar != dt || dt_tar != dt )
|
||||||
|
bli_abort();
|
||||||
|
|
||||||
|
// Extract various fields from the control tree.
|
||||||
|
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
|
||||||
|
bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
|
||||||
|
pack_t schema = bli_cntl_packm_params_pack_schema( cntl );
|
||||||
|
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
|
||||||
|
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
|
||||||
|
dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
|
||||||
|
|
||||||
|
if ( schema != BLIS_PACKED_ROW_PANELS &&
|
||||||
|
schema != BLIS_PACKED_COL_PANELS )
|
||||||
|
bli_abort();
|
||||||
|
|
||||||
|
// Store the pack schema to the object.
|
||||||
|
bli_obj_set_pack_schema( schema, p );
|
||||||
|
|
||||||
|
// Clear the conjugation field from the object since matrix packing
|
||||||
|
// in BLIS is deemed to take care of all conjugation necessary.
|
||||||
|
bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
|
||||||
|
|
||||||
|
// If we are packing micropanels, mark P as dense.
|
||||||
|
bli_obj_set_uplo( BLIS_DENSE, p );
|
||||||
|
|
||||||
|
// Reset the view offsets to (0,0).
|
||||||
|
bli_obj_set_offs( 0, 0, p );
|
||||||
|
|
||||||
|
// Compute the dimensions padded by the dimension multiples. These
|
||||||
|
// dimensions will be the dimensions of the packed matrices, including
|
||||||
|
// zero-padding, and will be used by the macro- and micro-kernels.
|
||||||
|
// We compute them by starting with the effective dimensions of A (now
|
||||||
|
// in P) and aligning them to the dimension multiples (typically equal
|
||||||
|
// to register blocksizes). This does waste a little bit of space for
|
||||||
|
// level-2 operations, but that's okay with us.
|
||||||
|
dim_t m_p = bli_obj_length( p );
|
||||||
|
dim_t n_p = bli_obj_width( p );
|
||||||
|
dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
|
||||||
|
dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
|
||||||
|
|
||||||
|
// Save the padded dimensions into the packed object. It is important
|
||||||
|
// to save these dimensions since they represent the actual dimensions
|
||||||
|
// of the zero-padded matrix.
|
||||||
|
bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
|
||||||
|
|
||||||
|
// The "panel stride" of a micropanel packed object is interpreted as
|
||||||
|
// the distance between the (0,0) element of panel k and the (0,0)
|
||||||
|
// element of panel k+1. We use the padded width computed above to
|
||||||
|
// allow for zero-padding (if necessary/desired) along the far end
|
||||||
|
// of each micropanel (ie: the right edge of the matrix). Zero-padding
|
||||||
|
// can also occur along the long edge of the last micropanel if the m
|
||||||
|
// dimension of the matrix is not a whole multiple of MR.
|
||||||
|
inc_t ps_p = bmult_m_pack * n_p_pad;
|
||||||
|
|
||||||
|
/* Compute the total number of iterations we'll need. */
|
||||||
|
dim_t n_iter = m_p_pad / bmult_m_def;
|
||||||
|
|
||||||
|
// Store the strides and panel dimension in P.
|
||||||
|
bli_obj_set_strides( 1, bmult_m_pack, p );
|
||||||
|
bli_obj_set_imag_stride( 1, p );
|
||||||
|
bli_obj_set_panel_dim( bmult_m_def, p );
|
||||||
|
bli_obj_set_panel_stride( ps_p, p );
|
||||||
|
bli_obj_set_panel_length( bmult_m_def, p );
|
||||||
|
bli_obj_set_panel_width( n_p, p );
|
||||||
|
|
||||||
|
// Compute the size of the packed buffer.
|
||||||
|
siz_t size_p = ps_p * n_iter * dt_size;
|
||||||
|
if ( size_p == 0 ) return;
|
||||||
|
|
||||||
|
// Update the buffer address in p to point to the buffer associated
|
||||||
|
// with the mem_t entry acquired from the memory broker (now cached in
|
||||||
|
// the control tree node).
|
||||||
|
char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
|
||||||
|
bli_obj_set_buffer( p_cast, p );
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
// Every thread initializes p and determines the size of memory
|
||||||
|
// block needed (which gets embedded into the otherwise "blank" mem_t
|
||||||
|
// entry in the control tree node). Return early if no packing is required.
|
||||||
|
if ( !bli_packm_init( a, p, cntx, rntm, cntl, thread ) )
|
||||||
|
return;
|
||||||
|
|
||||||
|
num_t dt = bli_obj_dt( a );
|
||||||
|
dim_t dt_size = bli_dt_size( dt );
|
||||||
|
|
||||||
|
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
|
||||||
|
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt, bmult_id_m, cntx );
|
||||||
|
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt, bmult_id_m, cntx );
|
||||||
|
|
||||||
|
dim_t m_p = bli_obj_length( p );
|
||||||
|
dim_t n_p = bli_obj_width( p );
|
||||||
|
dim_t m_p_pad = bli_obj_padded_length( p );
|
||||||
|
dim_t n_p_pad = bli_obj_padded_width( p );
|
||||||
|
dim_t n_iter = m_p_pad / bmult_m_def;
|
||||||
|
|
||||||
|
char* p_cast = bli_obj_buffer( p );
|
||||||
|
inc_t ps_p = bli_obj_panel_stride( p );
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
char* a_cast = bli_obj_buffer_at_off( a );
|
||||||
|
inc_t inca = bli_obj_row_stride( a );
|
||||||
|
inc_t lda = bli_obj_col_stride( a );
|
||||||
|
dim_t panel_len_off = bli_obj_col_off( a );
|
||||||
|
conj_t conja = bli_obj_conj_status( a );
|
||||||
|
|
||||||
|
packm_diag_params_t* params = bli_obj_pack_params( a );
|
||||||
|
char* d_cast = params->d;
|
||||||
|
inc_t incd = params->incd;
|
||||||
|
|
||||||
|
obj_t kappa_local;
|
||||||
|
char* kappa_cast = bli_packm_scalar( &kappa_local, p );
|
||||||
|
|
||||||
|
packm_diag_ukr_vft packm_ker_cast = packm_diag_ukrs[ dt ];
|
||||||
|
|
||||||
|
/* Query the number of threads and thread ids from the current thread's
|
||||||
|
packm thrinfo_t node. */
|
||||||
|
const dim_t nt = bli_thread_n_way( thread );
|
||||||
|
const dim_t tid = bli_thread_work_id( thread );
|
||||||
|
|
||||||
|
/* Determine the thread range and increment using the current thread's
|
||||||
|
packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
|
||||||
|
will depend on whether slab or round-robin partitioning was requested
|
||||||
|
at configure-time. */
|
||||||
|
dim_t it_start, it_end, it_inc;
|
||||||
|
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
|
||||||
|
|
||||||
|
/* Iterate over every logical micropanel in the source matrix. */
|
||||||
|
for ( dim_t it = 0; it < n_iter; it += 1 )
|
||||||
|
{
|
||||||
|
dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
|
||||||
|
|
||||||
|
char* d_begin = d_cast + panel_len_off*incd*dt_size;
|
||||||
|
char* a_begin = a_cast + it* bmult_m_def*inca*dt_size;
|
||||||
|
char* p_begin = p_cast + it* ps_p*dt_size;
|
||||||
|
|
||||||
|
if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
|
||||||
|
{
|
||||||
|
packm_ker_cast
|
||||||
|
(
|
||||||
|
conja,
|
||||||
|
panel_dim_i,
|
||||||
|
n_p,
|
||||||
|
bmult_m_def,
|
||||||
|
n_p_pad,
|
||||||
|
kappa_cast,
|
||||||
|
d_begin, incd,
|
||||||
|
a_begin, inca, lda,
|
||||||
|
p_begin, bmult_m_pack
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Modify the object A to include information about the diagonal D,
|
||||||
|
* and imbue it with special function pointers which will take care
|
||||||
|
* of the actual work of forming (D * A^T)
|
||||||
|
*/
|
||||||
|
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||||
|
{
|
||||||
|
// Assumes D is a column vector
|
||||||
|
params->d = bli_obj_buffer_at_off( d );
|
||||||
|
params->incd = bli_obj_row_stride( d );
|
||||||
|
|
||||||
|
// Set the custom pack function.
|
||||||
|
bli_obj_set_pack_fn( packm_diag, a );
|
||||||
|
|
||||||
|
// Attach the parameters to the A object.
|
||||||
|
bli_obj_set_pack_params( params, a );
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implements C := alpha * A * D * A^T + beta * C
|
||||||
|
*
|
||||||
|
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||||
|
*/
|
||||||
|
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||||
|
{
|
||||||
|
obj_t ad; // this is (D * A^T)
|
||||||
|
packm_diag_params_t params;
|
||||||
|
|
||||||
|
bli_obj_alias_to( a, &ad );
|
||||||
|
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||||
|
attach_diagonal_factor( ¶ms, d, &ad );
|
||||||
|
|
||||||
|
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||||
|
bli_gemmt( alpha, a, &ad, beta, c );
|
||||||
|
}
|
||||||
|
|
||||||
|
int main( void )
|
||||||
|
{
|
||||||
|
obj_t a;
|
||||||
|
obj_t d;
|
||||||
|
obj_t c;
|
||||||
|
obj_t c_copy;
|
||||||
|
obj_t norm;
|
||||||
|
|
||||||
|
dim_t m = 10;
|
||||||
|
dim_t k = 10;
|
||||||
|
|
||||||
|
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||||
|
for ( int upper = 0; upper <= 1; upper++ )
|
||||||
|
for ( int transa = 0; transa <= 1; transa++ )
|
||||||
|
for ( int transc = 0; transc <= 1; transc++ )
|
||||||
|
{
|
||||||
|
num_t dt = dt_;
|
||||||
|
uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||||
|
|
||||||
|
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||||
|
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||||
|
bli_obj_set_uplo( uplo , &c );
|
||||||
|
bli_obj_set_uplo( uplo , &c_copy );
|
||||||
|
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||||
|
|
||||||
|
bli_randm( &a );
|
||||||
|
bli_randm( &d );
|
||||||
|
bli_randm( &c );
|
||||||
|
bli_copym( &c, &c_copy );
|
||||||
|
|
||||||
|
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||||
|
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||||
|
|
||||||
|
bli_subm( &c_copy, &c );
|
||||||
|
bli_normfm( &c, &norm );
|
||||||
|
|
||||||
|
double normr, normi;
|
||||||
|
bli_getsc( &norm, &normr, &normi );
|
||||||
|
|
||||||
|
printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||||
|
dt, upper, transa, transc, normr );
|
||||||
|
|
||||||
|
bli_obj_free( &a );
|
||||||
|
bli_obj_free( &d );
|
||||||
|
bli_obj_free( &c );
|
||||||
|
bli_obj_free( &c_copy );
|
||||||
|
bli_obj_free( &norm );
|
||||||
|
}
|
||||||
|
}
|
||||||
338
test/syrk_diagonal/syrk_diagonal_example2.cxx
Normal file
338
test/syrk_diagonal/syrk_diagonal_example2.cxx
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
#include "syrk_diagonal_ref.h"

/*
 * Forward-declare the pack kernel and set up an array of packing
 * kernels, one for each data type.
 */
template <typename T>
void packm_diag_ukr
     (
       bool           conja,
       dim_t          panel_dim,
       dim_t          panel_len,
       dim_t          panel_dim_max,
       dim_t          panel_len_max,
       void* restrict kappa,
       void* restrict d, inc_t incd,
       void* restrict a, inc_t inca, inc_t lda,
       void* restrict p, inc_t ldp
     );

// Instantiate one kernel pointer per BLIS datatype via the usual
// type-expansion macros.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;

INSERT_GENTFUNC_BASIC0(packm_diag_ukr);

// Table of kernel function pointers indexed by datatype.
using packm_diag_ukr_vft = decltype(&packm_diag_ukr<void>);
static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Structure which includes all additional information beyond what is
|
||||||
|
* already stored in the obj_t structure.
|
||||||
|
*
|
||||||
|
* This structure is **read-only** during the operation!
|
||||||
|
*/
|
||||||
|
struct packm_diag_params_t
|
||||||
|
{
|
||||||
|
void* d;
|
||||||
|
inc_t incd;
|
||||||
|
|
||||||
|
packm_diag_params_t() {}
|
||||||
|
|
||||||
|
packm_diag_params_t( void* d, inc_t incd )
|
||||||
|
: d(d), incd(incd) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Selecting a different kernel based on the current architecture is
|
||||||
|
* currently not possible, but is something we plan to support.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void packm_diag_ukr
|
||||||
|
(
|
||||||
|
bool conja,
|
||||||
|
dim_t panel_dim,
|
||||||
|
dim_t panel_len,
|
||||||
|
dim_t panel_dim_max,
|
||||||
|
dim_t panel_len_max,
|
||||||
|
void* restrict kappa,
|
||||||
|
void* restrict d, inc_t incd,
|
||||||
|
void* restrict a, inc_t inca, inc_t lda,
|
||||||
|
void* restrict p, inc_t ldp
|
||||||
|
)
|
||||||
|
{
|
||||||
|
T* restrict a_cast = ( T* )a;
|
||||||
|
T* restrict p_cast = ( T* )p;
|
||||||
|
T* restrict d_cast = ( T* )d;
|
||||||
|
auto kappa_cast = *( T* )kappa;
|
||||||
|
|
||||||
|
if ( conja )
|
||||||
|
{
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ )
|
||||||
|
{
|
||||||
|
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||||
|
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++)
|
||||||
|
p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
|
||||||
|
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for ( dim_t j = 0; j < panel_len; j++ )
|
||||||
|
{
|
||||||
|
auto kappa_d = kappa_cast * d_cast[ j*incd ];
|
||||||
|
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++)
|
||||||
|
p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
|
||||||
|
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (dim_t j = panel_len;j < panel_len_max;j++)
|
||||||
|
for (dim_t i = 0;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i + j*ldp ] = convert<T>(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void packm_diag
|
||||||
|
(
|
||||||
|
obj_t* a,
|
||||||
|
obj_t* p,
|
||||||
|
cntx_t* cntx,
|
||||||
|
rntm_t* rntm,
|
||||||
|
cntl_t* cntl,
|
||||||
|
thrinfo_t* thread
|
||||||
|
)
|
||||||
|
{
|
||||||
|
// We begin by copying the fields of A.
|
||||||
|
bli_obj_alias_to( a, p );
|
||||||
|
|
||||||
|
// Get information about data types.
|
||||||
|
num_t dt = bli_obj_dt( a );
|
||||||
|
num_t dt_tar = bli_obj_target_dt( a );
|
||||||
|
num_t dt_scalar = bli_obj_scalar_dt( a );
|
||||||
|
dim_t dt_size = bli_dt_size( dt );
|
||||||
|
|
||||||
|
if ( dt_scalar != dt || dt_tar != dt )
|
||||||
|
bli_abort();
|
||||||
|
|
||||||
|
// Extract various fields from the control tree.
|
||||||
|
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl );
|
||||||
|
bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl );
|
||||||
|
pack_t schema = bli_cntl_packm_params_pack_schema( cntl );
|
||||||
|
dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
|
||||||
|
dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
|
||||||
|
dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
|
||||||
|
|
||||||
|
if ( schema != BLIS_PACKED_ROW_PANELS &&
|
||||||
|
schema != BLIS_PACKED_COL_PANELS )
|
||||||
|
bli_abort();
|
||||||
|
|
||||||
|
// Store the pack schema to the object.
|
||||||
|
bli_obj_set_pack_schema( schema, p );
|
||||||
|
|
||||||
|
// Clear the conjugation field from the object since matrix packing
|
||||||
|
// in BLIS is deemed to take care of all conjugation necessary.
|
||||||
|
bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
|
||||||
|
|
||||||
|
// If we are packing micropanels, mark P as dense.
|
||||||
|
bli_obj_set_uplo( BLIS_DENSE, p );
|
||||||
|
|
||||||
|
// Reset the view offsets to (0,0).
|
||||||
|
bli_obj_set_offs( 0, 0, p );
|
||||||
|
|
||||||
|
// Compute the dimensions padded by the dimension multiples. These
|
||||||
|
// dimensions will be the dimensions of the packed matrices, including
|
||||||
|
// zero-padding, and will be used by the macro- and micro-kernels.
|
||||||
|
// We compute them by starting with the effective dimensions of A (now
|
||||||
|
// in P) and aligning them to the dimension multiples (typically equal
|
||||||
|
// to register blocksizes). This does waste a little bit of space for
|
||||||
|
// level-2 operations, but that's okay with us.
|
||||||
|
dim_t m_p = bli_obj_length( p );
|
||||||
|
dim_t n_p = bli_obj_width( p );
|
||||||
|
dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
|
||||||
|
dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
|
||||||
|
|
||||||
|
// Save the padded dimensions into the packed object. It is important
|
||||||
|
// to save these dimensions since they represent the actual dimensions
|
||||||
|
// of the zero-padded matrix.
|
||||||
|
bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
|
||||||
|
|
||||||
|
// The "panel stride" of a micropanel packed object is interpreted as
|
||||||
|
// the distance between the (0,0) element of panel k and the (0,0)
|
||||||
|
// element of panel k+1. We use the padded width computed above to
|
||||||
|
// allow for zero-padding (if necessary/desired) along the far end
|
||||||
|
// of each micropanel (ie: the right edge of the matrix). Zero-padding
|
||||||
|
// can also occur along the long edge of the last micropanel if the m
|
||||||
|
// dimension of the matrix is not a whole multiple of MR.
|
||||||
|
inc_t ps_p = bmult_m_pack * n_p_pad;
|
||||||
|
|
||||||
|
/* Compute the total number of iterations we'll need. */
|
||||||
|
dim_t n_iter = m_p_pad / bmult_m_def;
|
||||||
|
|
||||||
|
// Store the strides and panel dimension in P.
|
||||||
|
bli_obj_set_strides( 1, bmult_m_pack, p );
|
||||||
|
bli_obj_set_imag_stride( 1, p );
|
||||||
|
bli_obj_set_panel_dim( bmult_m_def, p );
|
||||||
|
bli_obj_set_panel_stride( ps_p, p );
|
||||||
|
bli_obj_set_panel_length( bmult_m_def, p );
|
||||||
|
bli_obj_set_panel_width( n_p, p );
|
||||||
|
|
||||||
|
// Compute the size of the packed buffer.
|
||||||
|
siz_t size_p = ps_p * n_iter * dt_size;
|
||||||
|
if ( size_p == 0 ) return;
|
||||||
|
|
||||||
|
// Update the buffer address in p to point to the buffer associated
|
||||||
|
// with the mem_t entry acquired from the memory broker (now cached in
|
||||||
|
// the control tree node).
|
||||||
|
char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
|
||||||
|
bli_obj_set_buffer( p_cast, p );
|
||||||
|
|
||||||
|
char* a_cast = (char*)bli_obj_buffer_at_off( a );
|
||||||
|
inc_t inca = bli_obj_row_stride( a );
|
||||||
|
inc_t lda = bli_obj_col_stride( a );
|
||||||
|
dim_t panel_len_off = bli_obj_col_off( a );
|
||||||
|
conj_t conja = bli_obj_conj_status( a );
|
||||||
|
|
||||||
|
auto params = (packm_diag_params_t*)bli_obj_pack_params( a );
|
||||||
|
char* d_cast = (char*)params->d;
|
||||||
|
inc_t incd = params->incd;
|
||||||
|
|
||||||
|
obj_t kappa_local;
|
||||||
|
char* kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );
|
||||||
|
|
||||||
|
auto packm_ker_cast = packm_diag_ukrs[ dt ];
|
||||||
|
|
||||||
|
/* Query the number of threads and thread ids from the current thread's
|
||||||
|
packm thrinfo_t node. */
|
||||||
|
const dim_t nt = bli_thread_n_way( thread );
|
||||||
|
const dim_t tid = bli_thread_work_id( thread );
|
||||||
|
|
||||||
|
/* Determine the thread range and increment using the current thread's
|
||||||
|
packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
|
||||||
|
will depend on whether slab or round-robin partitioning was requested
|
||||||
|
at configure-time. */
|
||||||
|
dim_t it_start, it_end, it_inc;
|
||||||
|
bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
|
||||||
|
|
||||||
|
/* Iterate over every logical micropanel in the source matrix. */
|
||||||
|
for ( dim_t it = 0; it < n_iter; it += 1 )
|
||||||
|
{
|
||||||
|
dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
|
||||||
|
|
||||||
|
char* d_begin = d_cast + panel_len_off*incd*dt_size;
|
||||||
|
char* a_begin = a_cast + it* bmult_m_def*inca*dt_size;
|
||||||
|
char* p_begin = p_cast + it* ps_p*dt_size;
|
||||||
|
|
||||||
|
if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
|
||||||
|
{
|
||||||
|
packm_ker_cast( conja,
|
||||||
|
panel_dim_i,
|
||||||
|
n_p,
|
||||||
|
bmult_m_def,
|
||||||
|
n_p_pad,
|
||||||
|
kappa_cast,
|
||||||
|
d_begin, incd,
|
||||||
|
a_begin, inca, lda,
|
||||||
|
p_begin, bmult_m_pack );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Modify the object A to include information about the diagonal D,
|
||||||
|
* and imbue it with special function pointers which will take care
|
||||||
|
* of the actual work of forming (D * A^T)
|
||||||
|
*/
|
||||||
|
void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
|
||||||
|
{
|
||||||
|
// Assumes D is a column vector
|
||||||
|
new (params) packm_diag_params_t
|
||||||
|
(
|
||||||
|
bli_obj_buffer_at_off( d ),
|
||||||
|
bli_obj_row_stride( d )
|
||||||
|
);
|
||||||
|
|
||||||
|
// Set the custom pack function.
|
||||||
|
bli_obj_set_pack_fn( packm_diag, a );
|
||||||
|
|
||||||
|
// Attach the parameters to the A object.
|
||||||
|
bli_obj_set_pack_params( params, a );
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Implements C := alpha * A * D * A^T + beta * C
|
||||||
|
*
|
||||||
|
* where D is a diagonal matrix with elements taken from the "d" vector.
|
||||||
|
*/
|
||||||
|
void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
|
||||||
|
{
|
||||||
|
obj_t ad; // this is (D * A^T)
|
||||||
|
packm_diag_params_t params;
|
||||||
|
|
||||||
|
bli_obj_alias_to( a, &ad );
|
||||||
|
bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
|
||||||
|
attach_diagonal_factor( ¶ms, d, &ad );
|
||||||
|
|
||||||
|
// Does C := alpha * A * B + beta * C using B = (D + A^T)
|
||||||
|
bli_gemmt( alpha, a, &ad, beta, c );
|
||||||
|
}
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
obj_t a;
|
||||||
|
obj_t d;
|
||||||
|
obj_t c;
|
||||||
|
obj_t c_copy;
|
||||||
|
obj_t norm;
|
||||||
|
|
||||||
|
auto m = 10;
|
||||||
|
auto k = 10;
|
||||||
|
|
||||||
|
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||||
|
for ( int upper = 0; upper <= 1; upper++ )
|
||||||
|
for ( int transa = 0; transa <= 1; transa++ )
|
||||||
|
for ( int transc = 0; transc <= 1; transc++ )
|
||||||
|
{
|
||||||
|
auto dt = ( num_t )dt_;
|
||||||
|
auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
|
||||||
|
|
||||||
|
bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
|
||||||
|
bli_obj_create( dt, k, 1, 1, 1, &d );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
|
||||||
|
bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c );
|
||||||
|
bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy );
|
||||||
|
bli_obj_set_uplo( uplo , &c );
|
||||||
|
bli_obj_set_uplo( uplo , &c_copy );
|
||||||
|
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||||
|
|
||||||
|
bli_randm( &a );
|
||||||
|
bli_randm( &d );
|
||||||
|
bli_randm( &c );
|
||||||
|
bli_copym( &c, &c_copy );
|
||||||
|
|
||||||
|
syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
|
||||||
|
syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
|
||||||
|
|
||||||
|
bli_subm( &c_copy, &c );
|
||||||
|
bli_normfm( &c, &norm );
|
||||||
|
|
||||||
|
double normr, normi;
|
||||||
|
bli_getsc( &norm, &normr, &normi );
|
||||||
|
|
||||||
|
printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
|
||||||
|
dt, upper, transa, transc, normr);
|
||||||
|
|
||||||
|
bli_obj_free( &a );
|
||||||
|
bli_obj_free( &d );
|
||||||
|
bli_obj_free( &c );
|
||||||
|
bli_obj_free( &c_copy );
|
||||||
|
bli_obj_free( &norm );
|
||||||
|
}
|
||||||
|
}
|
||||||
102
test/syrk_diagonal/syrk_diagonal_ref.cxx
Normal file
102
test/syrk_diagonal/syrk_diagonal_ref.cxx
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
#include "syrk_diagonal_ref.h"
#include "complex_math.hpp"

// Signature of a type-specific reference implementation of
// C := alpha * A * D * A^T + beta * C. (Modernized from a typedef to a
// `using` alias; the file already relies on C++11 features.)
using syrk_diag_ref_vft = void (*)
    (
      uplo_t uplo,
      dim_t  m,
      dim_t  k,
      void*  alpha,
      void*  a, inc_t rs_a, inc_t cs_a,
      void*  d, inc_t incd,
      void*  beta,
      void*  c, inc_t rs_c, inc_t cs_c
    );
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void syrk_diag_ref
|
||||||
|
(
|
||||||
|
uplo_t uplo,
|
||||||
|
dim_t m,
|
||||||
|
dim_t k,
|
||||||
|
void* alpha,
|
||||||
|
void* a, inc_t rs_a, inc_t cs_a,
|
||||||
|
void* d, inc_t incd,
|
||||||
|
void* beta,
|
||||||
|
void* c, inc_t rs_c, inc_t cs_c
|
||||||
|
)
|
||||||
|
{
|
||||||
|
auto alpha_cast = *( T* )alpha;
|
||||||
|
auto beta_cast = *( T* )beta;
|
||||||
|
auto a_cast = ( T* )a;
|
||||||
|
auto d_cast = ( T* )d;
|
||||||
|
auto c_cast = ( T* )c;
|
||||||
|
|
||||||
|
for ( dim_t i = 0; i < m; i++ )
|
||||||
|
{
|
||||||
|
dim_t j_min = uplo == BLIS_UPPER ? i : 0;
|
||||||
|
dim_t j_max = uplo == BLIS_UPPER ? m : i+1;
|
||||||
|
|
||||||
|
for ( dim_t j = j_min; j < j_max; j++ )
|
||||||
|
{
|
||||||
|
auto ada = convert<T>(0.0);
|
||||||
|
|
||||||
|
for ( dim_t p = 0; p < k; p++ )
|
||||||
|
{
|
||||||
|
ada += a_cast[ i*rs_a + p*cs_a ] *
|
||||||
|
d_cast[ p*incd ] *
|
||||||
|
a_cast[ j*rs_a + p*cs_a ];
|
||||||
|
}
|
||||||
|
|
||||||
|
if ( beta_cast == convert<T>(0.0) )
|
||||||
|
{
|
||||||
|
c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada +
|
||||||
|
beta_cast * c_cast[ i*rs_c + j*cs_c ];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef GENTFUNC
|
||||||
|
#define GENTFUNC(ctype,ch,op) \
|
||||||
|
static auto PASTEMAC(ch,op) = &syrk_diag_ref<ctype>;
|
||||||
|
|
||||||
|
INSERT_GENTFUNC_BASIC0(syrk_diag_ref);
|
||||||
|
|
||||||
|
static syrk_diag_ref_vft GENARRAY( syrk_diag_ref_impl, syrk_diag_ref );
|
||||||
|
|
||||||
|
/*
 * Object-based wrapper: extracts dimensions, strides, and buffers from
 * the obj_t operands and dispatches to the type-specific kernel.
 */
void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
{
    num_t dt = bli_obj_dt( a );

    dim_t m = bli_obj_length_after_trans( a );
    dim_t k = bli_obj_width_after_trans( a );

    inc_t rs_a = bli_obj_row_stride( a );
    inc_t cs_a = bli_obj_col_stride( a );
    inc_t rs_c = bli_obj_row_stride( c );
    inc_t cs_c = bli_obj_col_stride( c );
    inc_t incd = bli_obj_row_stride( d );

    // A transposed operand is handled by exchanging its strides.
    if ( bli_obj_has_trans( a ) )
        bli_swap_incs( &rs_a, &cs_a );

    if ( bli_obj_has_trans( c ) )
        bli_swap_incs( &rs_c, &cs_c );

    syrk_diag_ref_impl[ dt ]
    (
      bli_obj_uplo( c ),
      m, k,
      bli_obj_buffer_for_1x1( dt, alpha ),
      bli_obj_buffer_at_off( a ), rs_a, cs_a,
      bli_obj_buffer_at_off( d ), incd,
      bli_obj_buffer_for_1x1( dt, beta ),
      bli_obj_buffer_at_off( c ), rs_c, cs_c
    );
}
|
||||||
|
|
||||||
8
test/syrk_diagonal/syrk_diagonal_ref.h
Normal file
8
test/syrk_diagonal/syrk_diagonal_ref.h
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#pragma once  // guard against multiple inclusion

#include "blis.h"

// Reference implementation of C := alpha * A * D * A^T + beta * C.
// Declared with C linkage when compiled as C++ so C and C++ translation
// units share one symbol.
#ifdef __cplusplus
#include "complex_math.hpp"
extern "C"
#endif
void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c );
|
||||||
|
|
||||||
267
test/tensor_contraction/complex_math.hpp
Normal file
267
test/tensor_contraction/complex_math.hpp
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
// This header defines class templates and inline functions, so it must be
// protected against multiple inclusion (it previously had no guard, which
// made double inclusion within one translation unit ill-formed).
#pragma once

#include <cmath>
#include <algorithm>
#include <type_traits>

#include "blis.h"

// Trait identifying the BLIS complex types.
template <typename T>
struct is_complex : std::false_type {};

template <>
struct is_complex<scomplex> : std::true_type {};

template <>
struct is_complex<dcomplex> : std::true_type {};

// A type is "real" exactly when it is not complex.
template <typename T>
struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};

// Map a real or complex type to its complex counterpart...
template <typename T> struct make_complex;

template <> struct make_complex<float   > { using type = scomplex; };
template <> struct make_complex<double  > { using type = dcomplex; };
template <> struct make_complex<scomplex> { using type = scomplex; };
template <> struct make_complex<dcomplex> { using type = dcomplex; };

template <typename T>
using make_complex_t = typename make_complex<T>::type;

// ...and to its real counterpart.
template <typename T> struct make_real;

template <> struct make_real<float   > { using type = float;  };
template <> struct make_real<double  > { using type = double; };
template <> struct make_real<scomplex> { using type = float;  };
template <> struct make_real<dcomplex> { using type = double; };

template <typename T>
using make_real_t = typename make_real<T>::type;

// Select the complex or real counterpart of T depending on Cond.
template <typename T, bool Cond>
struct make_complex_if : std::conditional<Cond,make_complex_t<T>,make_real_t<T>> {};

template <typename T, bool Cond>
using make_complex_if_t = typename make_complex_if<T,Cond>::type;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct real_imag_part
|
||||||
|
{
|
||||||
|
real_imag_part& operator=(T) { return *this; }
|
||||||
|
|
||||||
|
operator T() const { return T(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_arithmetic<typename std::remove_cv<T>::type>::value,T&> real(T& x) { return x; }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<std::is_arithmetic<T>::value,real_imag_part<T>> imag(T x) { return {}; }
|
||||||
|
|
||||||
|
inline float& real(scomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline float& imag(scomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
inline double& real(dcomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline double& imag(dcomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
inline const float& real(const scomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline const float& imag(const scomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
inline const double& real(const dcomplex& x) { return x.real; }
|
||||||
|
|
||||||
|
inline const double& imag(const dcomplex& x) { return x.imag; }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
|
||||||
|
|
||||||
|
template <typename T, typename U, typename=void>
|
||||||
|
struct convert_impl;
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y = x; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y = x.real; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename U>
|
||||||
|
struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
|
||||||
|
{
|
||||||
|
void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U, typename T>
|
||||||
|
U convert(T x)
|
||||||
|
{
|
||||||
|
U y;
|
||||||
|
convert_impl<T,U>{}(x,y);
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U, typename T>
|
||||||
|
auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
|
||||||
|
{
|
||||||
|
return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define COMPLEX_MATH_OPS(rtype, ctype) \
|
||||||
|
\
|
||||||
|
inline bool operator==(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return x == y.real && y.imag == 0; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline bool operator==(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return y == x.real && x.imag == 0; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline bool operator==(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return x.real == y.real && \
|
||||||
|
x.imag == y.imag; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(ctype x) \
|
||||||
|
{ \
|
||||||
|
return {-x.real, -x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator+(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x+y.real, y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator+(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {y+x.real, x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator+(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real+y.real, x.imag+y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x-y.real, -y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real-y, x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator-(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real-y.real, x.imag-y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator*(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x*y.real, x*y.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator*(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {y*x.real, y*x.imag}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator*(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real*y.real - x.imag*y.imag, \
|
||||||
|
x.real*y.imag + x.imag*y.real}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator/(rtype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
|
||||||
|
auto n = std::ilogb(scale); \
|
||||||
|
auto yrs = std::scalbn(y.real, -n); \
|
||||||
|
auto yis = std::scalbn(y.imag, -n); \
|
||||||
|
auto denom = y.real*yrs + y.imag*yis; \
|
||||||
|
return {x*yrs/denom, -x*yis/denom}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator/(ctype x, rtype y) \
|
||||||
|
{ \
|
||||||
|
return {x.real/y, x.imag/y}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype operator/(ctype x, ctype y) \
|
||||||
|
{ \
|
||||||
|
auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
|
||||||
|
auto n = std::ilogb(scale); \
|
||||||
|
auto yrs = std::scalbn(y.real, -n); \
|
||||||
|
auto yis = std::scalbn(y.imag, -n); \
|
||||||
|
auto denom = y.real*yrs + y.imag*yis; \
|
||||||
|
return {(x.real*yrs + x.imag*yis)/denom, \
|
||||||
|
(x.imag*yrs - x.real*yis)/denom}; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator+=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real += y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator+=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x.real += y.real; x.imag += y.imag; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator-=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real -= y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator-=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x.real -= y.real; x.imag -= y.imag; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator*=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real *= y; x.imag *= y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator*=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x = x * y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator/=(ctype& x, rtype y) \
|
||||||
|
{ \
|
||||||
|
x.real /= y; x.imag /= y; \
|
||||||
|
return x; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
inline ctype& operator/=(ctype& x, ctype y) \
|
||||||
|
{ \
|
||||||
|
x = x / y; \
|
||||||
|
return x; \
|
||||||
|
}
|
||||||
|
|
||||||
|
COMPLEX_MATH_OPS(float, scomplex);
|
||||||
|
COMPLEX_MATH_OPS(double, dcomplex);
|
||||||
|
|
||||||
988
test/tensor_contraction/tcontract_example.cxx
Normal file
988
test/tensor_contraction/tcontract_example.cxx
Normal file
@@ -0,0 +1,988 @@
|
|||||||
|
|
||||||
|
#include "tcontract_ref.hpp"

#include <algorithm>
#include <numeric>

// Block size along the k dimension: the scatter and block-stride vectors
// are consumed in chunks of this many elements.
static constexpr dim_t BS_K = 8;

/*
 * Additional information beyond what is stored in the obj_t structure:
 * the number of tensor dimensions folded into the matrix m and n indices,
 * together with their lengths and strides.
 *
 * This structure is read-only during the operation.
 */
struct packm_tensor_params_t
{
    gint_t ndim_m = 0, ndim_n = 0;            // dimension counts
    const dim_t *len_m = nullptr, *len_n = nullptr;       // dimension lengths
    const inc_t *stride_m = nullptr, *stride_n = nullptr; // dimension strides

    // Default-construct with well-defined (zero/null) members instead of
    // leaving them uninitialized.
    packm_tensor_params_t() = default;

    packm_tensor_params_t( gint_t ndim_m, const dim_t* len_m, const inc_t* stride_m,
                           gint_t ndim_n, const dim_t* len_n, const inc_t* stride_n )
    : ndim_m(ndim_m), ndim_n(ndim_n),
      len_m(len_m), len_n(len_n),
      stride_m(stride_m), stride_n(stride_n) {}
};

// The gemm kernel consumes the same parameter bundle.
using gemm_tensor_params_t = packm_tensor_params_t;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void packm_ckx_nb
|
||||||
|
(
|
||||||
|
bool conja,
|
||||||
|
dim_t panel_dim,
|
||||||
|
dim_t panel_len,
|
||||||
|
dim_t panel_dim_max,
|
||||||
|
dim_t panel_len_max,
|
||||||
|
void* kappa,
|
||||||
|
void* a, inc_t inca, inc_t* bsa, inc_t* scata,
|
||||||
|
void* p, inc_t ldp
|
||||||
|
)
|
||||||
|
{
|
||||||
|
T* restrict a_cast = ( T* )a;
|
||||||
|
T* restrict p_cast = ( T* )p;
|
||||||
|
auto kappa_cast = *( T* )kappa;
|
||||||
|
|
||||||
|
if ( conja )
|
||||||
|
{
|
||||||
|
for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
|
||||||
|
{
|
||||||
|
auto lda = *bsa;
|
||||||
|
auto panel_len_j = std::min<dim_t>( panel_len-j0, BS_K );
|
||||||
|
|
||||||
|
if ( lda )
|
||||||
|
{
|
||||||
|
T* restrict aj = a_cast + *scata;
|
||||||
|
|
||||||
|
for ( auto j = 0; j < panel_len_j; j++ )
|
||||||
|
{
|
||||||
|
for ( auto i = 0; i < panel_dim; i++ )
|
||||||
|
p_cast[ i ] = kappa_cast * conj( aj[ i*inca + j*lda ] );
|
||||||
|
|
||||||
|
for ( auto i = panel_dim; i < panel_dim_max; i++ )
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for ( auto j = 0; j < panel_len_j; j++)
|
||||||
|
{
|
||||||
|
for ( auto i = 0; i < panel_dim; i++)
|
||||||
|
p_cast[ i ] = kappa_cast * conj( a_cast[ i*inca + scata[j] ] );
|
||||||
|
|
||||||
|
for ( auto i = panel_dim; i < panel_dim_max; i++)
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
|
||||||
|
{
|
||||||
|
auto lda = *bsa;
|
||||||
|
auto panel_len_j = std::min<dim_t>( panel_len-j0, BS_K );
|
||||||
|
|
||||||
|
if ( lda )
|
||||||
|
{
|
||||||
|
T* restrict aj = a_cast + *scata;
|
||||||
|
|
||||||
|
for ( auto j = 0; j < panel_len_j; j++ )
|
||||||
|
{
|
||||||
|
for ( auto i = 0; i < panel_dim; i++ )
|
||||||
|
p_cast[ i ] = kappa_cast * aj[ i*inca + j*lda ];
|
||||||
|
|
||||||
|
for ( auto i = panel_dim; i < panel_dim_max; i++ )
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for ( auto j = 0; j < panel_len_j; j++ )
|
||||||
|
{
|
||||||
|
for ( auto i = 0; i < panel_dim; i++ )
|
||||||
|
p_cast[ i ] = kappa_cast * a_cast[ i*inca + scata[j] ];
|
||||||
|
|
||||||
|
for ( auto i = panel_dim; i < panel_dim_max; i++ )
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for ( auto j = panel_len; j < panel_len_max; j++)
|
||||||
|
{
|
||||||
|
for ( auto i = 0; i < panel_dim_max; i++)
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void packm_ckx_ss
|
||||||
|
(
|
||||||
|
bool conja,
|
||||||
|
dim_t panel_dim,
|
||||||
|
dim_t panel_len,
|
||||||
|
dim_t panel_dim_max,
|
||||||
|
dim_t panel_len_max,
|
||||||
|
void* kappa,
|
||||||
|
void* a, inc_t* inca, inc_t* scata,
|
||||||
|
void* p, inc_t ldp
|
||||||
|
)
|
||||||
|
{
|
||||||
|
T* restrict a_cast = ( T* )a;
|
||||||
|
T* restrict p_cast = ( T* )p;
|
||||||
|
auto kappa_cast = *( T* )kappa;
|
||||||
|
|
||||||
|
if ( conja )
|
||||||
|
{
|
||||||
|
for (dim_t j = 0;j < panel_len;j++)
|
||||||
|
{
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++)
|
||||||
|
p_cast[ i ] = kappa_cast * conj( a_cast[ inca[i] + scata[j] ] );
|
||||||
|
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (dim_t j = 0;j < panel_len;j++)
|
||||||
|
{
|
||||||
|
for (dim_t i = 0;i < panel_dim;i++)
|
||||||
|
p_cast[ i ] = kappa_cast * a_cast[ inca[i] + scata[j] ];
|
||||||
|
|
||||||
|
for (dim_t i = panel_dim;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (dim_t j = panel_len;j < panel_len_max;j++)
|
||||||
|
{
|
||||||
|
for (dim_t i = 0;i < panel_dim_max;i++)
|
||||||
|
p_cast[ i ] = convert<T>(0.0);
|
||||||
|
|
||||||
|
p_cast += ldp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef GENTFUNC
|
||||||
|
#define GENTFUNC(ctype,ch,op) \
|
||||||
|
static auto PASTEMAC(ch,op) = &packm_ckx_nb<ctype>;
|
||||||
|
|
||||||
|
INSERT_GENTFUNC_BASIC0(packm_ckx_nb);
|
||||||
|
|
||||||
|
#undef GENTFUNC
|
||||||
|
#define GENTFUNC(ctype,ch,op) \
|
||||||
|
static auto PASTEMAC(ch,op) = &packm_ckx_ss<ctype>;
|
||||||
|
|
||||||
|
INSERT_GENTFUNC_BASIC0(packm_ckx_ss);
|
||||||
|
|
||||||
|
static decltype(&packm_ckx_nb<void>) GENARRAY( packm_ckx_nb_ukrs, packm_ckx_nb );
|
||||||
|
static decltype(&packm_ckx_ss<void>) GENARRAY( packm_ckx_ss_ukrs, packm_ckx_ss );
|
||||||
|
|
||||||
|
static void fill_scatter
|
||||||
|
(
|
||||||
|
gint_t ndim,
|
||||||
|
const dim_t* restrict len,
|
||||||
|
const inc_t* restrict stride,
|
||||||
|
dim_t BS,
|
||||||
|
inc_t off,
|
||||||
|
dim_t size,
|
||||||
|
inc_t* restrict scat,
|
||||||
|
inc_t* restrict bs
|
||||||
|
)
|
||||||
|
{
|
||||||
|
if ( size == 0 ) return;
|
||||||
|
|
||||||
|
if ( ndim == 0 )
|
||||||
|
{
|
||||||
|
*scat = 0;
|
||||||
|
*bs = 0;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ( ndim == 1 )
|
||||||
|
{
|
||||||
|
auto l = *len;
|
||||||
|
auto s = *stride;
|
||||||
|
for ( auto i = 0; i < l; i++ )
|
||||||
|
{
|
||||||
|
scat[i] = i*s;
|
||||||
|
bs[i] = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dim_t tot_len = 1;
|
||||||
|
for ( auto i = 0; i < ndim; i++ )
|
||||||
|
tot_len *= len[i];
|
||||||
|
|
||||||
|
assert(off >= 0);
|
||||||
|
assert(size >= 0);
|
||||||
|
assert(off+size <= tot_len);
|
||||||
|
|
||||||
|
auto len0 = len[0];
|
||||||
|
auto stride0 = stride[0];
|
||||||
|
auto off0 = off % len0;
|
||||||
|
auto off1 = off / len0;
|
||||||
|
auto size1 = ( size + off0 + len0 - 1) / len0;
|
||||||
|
|
||||||
|
inc_t pos1 = 0;
|
||||||
|
inc_t idx = 0;
|
||||||
|
for_each( ndim-1, len+1, off1, size1, pos1, stride+1,
|
||||||
|
[&]
|
||||||
|
{
|
||||||
|
auto pos = pos1 + off0 * stride0;
|
||||||
|
auto len_i = std::min( len0-off0, size-idx );
|
||||||
|
for ( auto i = 0; i < len_i; i++ )
|
||||||
|
{
|
||||||
|
scat[idx++] = pos;
|
||||||
|
pos += stride0;
|
||||||
|
}
|
||||||
|
off0 = 0;
|
||||||
|
});
|
||||||
|
assert(idx == size);
|
||||||
|
|
||||||
|
for ( idx = 0; idx < size; idx += BS )
|
||||||
|
{
|
||||||
|
auto len_i = std::min( BS, size-idx );
|
||||||
|
auto s = stride0;
|
||||||
|
|
||||||
|
for ( auto i = idx; i < idx+len_i-1; i++)
|
||||||
|
{
|
||||||
|
if (scat[i+1]-scat[i] != s)
|
||||||
|
{
|
||||||
|
s = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bs[idx] = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom packing function (installed on an obj_t via bli_obj_set_pack_fn())
// which packs a tensor view -- described by the packm_tensor_params_t
// attached to the source object -- into the micropanel format expected by
// the gemm macrokernel. Element gathering is driven by scatter and
// block-scatter vectors computed from the tensor lengths and strides.
//
// a      - source object; its pack params describe the tensor layout.
// p      - object configured here as the packed copy of a.
// cntx   - context holding blocksizes and kernel information.
// rntm   - runtime state (used for buffer allocation).
// cntl   - control-tree node caching the packed-buffer mem_t entry.
// thread - packm thrinfo_t node for this operation.
void packm_tensor
     (
       obj_t*     a,
       obj_t*     p,
       cntx_t*    cntx,
       rntm_t*    rntm,
       cntl_t*    cntl,
       thrinfo_t* thread
     )
{
	// We begin by copying the fields of A.
	bli_obj_alias_to( a, p );

	// Get information about data types.
	auto dt        = bli_obj_dt( a );
	auto dt_tar    = bli_obj_target_dt( a );
	auto dt_scalar = bli_obj_scalar_dt( a );
	auto dt_size   = bli_dt_size( dt );

	// Mixed-datatype packing is not supported here.
	if ( dt_scalar != dt || dt_tar != dt )
		bli_abort();

	// Extract various fields from the control tree.
	auto bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
	auto bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
	auto schema       = bli_cntl_packm_params_pack_schema( cntl );
	auto bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
	auto bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
	auto bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );

	// Only micropanel pack schemas are supported.
	if ( schema != BLIS_PACKED_ROW_PANELS &&
	     schema != BLIS_PACKED_COL_PANELS )
		bli_abort();

	// Store the pack schema to the object.
	bli_obj_set_pack_schema( schema, p );

	// Clear the conjugation field from the object since matrix packing
	// in BLIS is deemed to take care of all conjugation necessary.
	bli_obj_set_conj( BLIS_NO_CONJUGATE, p );

	// If we are packing micropanels, mark P as dense.
	bli_obj_set_uplo( BLIS_DENSE, p );

	// Reset the view offsets to (0,0).
	bli_obj_set_offs( 0, 0, p );

	// Compute the dimensions padded by the dimension multiples. These
	// dimensions will be the dimensions of the packed matrices, including
	// zero-padding, and will be used by the macro- and micro-kernels.
	// We compute them by starting with the effective dimensions of A (now
	// in P) and aligning them to the dimension multiples (typically equal
	// to register blocksizes). This does waste a little bit of space for
	// level-2 operations, but that's okay with us.
	auto m_p     = bli_obj_length( p );
	auto n_p     = bli_obj_width( p );
	auto m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
	auto n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );

	// Save the padded dimensions into the packed object. It is important
	// to save these dimensions since they represent the actual dimensions
	// of the zero-padded matrix.
	bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );

	// The "panel stride" of a micropanel packed object is interpreted as
	// the distance between the (0,0) element of panel k and the (0,0)
	// element of panel k+1. We use the padded width computed above to
	// allow for zero-padding (if necessary/desired) along the far end
	// of each micropanel (ie: the right edge of the matrix). Zero-padding
	// can also occur along the long edge of the last micropanel if the m
	// dimension of the matrix is not a whole multiple of MR.
	auto ps_p = bmult_m_pack * n_p_pad;

	/* Compute the total number of iterations we'll need. */
	auto n_iter = m_p_pad / bmult_m_def;

	// Store the strides and panel dimension in P.
	bli_obj_set_strides( 1, bmult_m_pack, p );
	bli_obj_set_imag_stride( 1, p );
	bli_obj_set_panel_dim( bmult_m_def, p );
	bli_obj_set_panel_stride( ps_p, p );
	bli_obj_set_panel_length( bmult_m_def, p );
	bli_obj_set_panel_width( n_p, p );

	// Compute the size of the packed buffer.
	auto size_p = ps_p * n_iter * dt_size;
	if ( size_p == 0 ) return;

	// Compute the size of the scatter and block-scatter vectors to the total.
	// It is never necessary to add padding for alignment because:
	// 1) ps_p is always even
	// 2) dt_size is a power of two >= 4
	// 3) the alignment of the scatter vectors is at most 8
	auto scat_size = 2 * (m_p + n_p) * sizeof(inc_t);

	// Update the buffer address in p to point to the buffer associated
	// with the mem_t entry acquired from the memory broker (now cached in
	// the control tree node).
	auto p_cast = (char*)bli_packm_alloc( size_p + scat_size, rntm, cntl, thread );
	bli_obj_set_buffer( p_cast, p );

	// Get the addresses of the scatter and block-scatter vectors. These are
	// placed directly after the packed matrix buffer.
	auto rscat = (inc_t*)(p_cast + size_p);
	auto rbs   = rscat + m_p;
	auto cscat = rbs + m_p;
	auto cbs   = cscat + n_p;

	auto a_cast        = (char*)bli_obj_buffer_at_off( a );
	auto panel_dim_off = bli_obj_row_off( a );
	auto panel_len_off = bli_obj_col_off( a );
	auto conja         = bli_obj_conj_status( a );

	// The tensor lengths and strides for the panel (m) and panel-length (n)
	// index groups are supplied by the caller via the pack params.
	auto params   = (packm_tensor_params_t*)bli_obj_pack_params( a );
	auto ndim_m   = params->ndim_m;
	auto ndim_n   = params->ndim_n;
	auto len_m    = params->len_m;
	auto len_n    = params->len_n;
	auto stride_m = params->stride_m;
	auto stride_n = params->stride_n;

	obj_t kappa_local;
	auto kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );

	auto packm_nb_ker = packm_ckx_nb_ukrs[ dt ];
	auto packm_ss_ker = packm_ckx_ss_ukrs[ dt ];

	// Back the source pointer up to the tensor origin; the scatter vectors
	// computed below are relative to element (0,...,0).
	a_cast -= ( panel_dim_off * stride_m[0] +
	            panel_len_off * stride_n[0] ) * dt_size;

	/* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */
	if ( bli_thread_am_ochief( thread ) )
	{
		fill_scatter
		(
		  ndim_m,
		  len_m,
		  stride_m,
		  bmult_m_def,
		  panel_dim_off,
		  m_p,
		  rscat,
		  rbs
		);

		fill_scatter
		(
		  ndim_n,
		  len_n,
		  stride_n,
		  BS_K,
		  panel_len_off,
		  n_p,
		  cscat,
		  cbs
		);
	}

	/* Wait for the scatter vectors to be done. */
	bli_thread_barrier( thread );

	/* Query the number of threads and thread ids from the current thread's
	   packm thrinfo_t node. */
	auto nt  = bli_thread_n_way( thread );
	auto tid = bli_thread_work_id( thread );

	/* Determine the thread range and increment using the current thread's
	   packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
	   will depend on whether slab or round-robin partitioning was requested
	   at configure-time. */
	dim_t it_start, it_end, it_inc;
	bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );

	/* Iterate over every logical micropanel in the source matrix. */
	for ( auto it = 0; it < n_iter; it += 1 )
	if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
	{
		auto panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );

		auto p_begin = p_cast + it*ps_p*dt_size;
		// A non-zero block-scatter entry means this panel's rows are
		// regularly spaced, so the faster strided kernel can be used.
		auto inca    = rbs[ it*bmult_m_def ];

		if ( inca )
		{
			auto a_begin = a_cast + rscat[ it*bmult_m_def ]*dt_size;

			packm_nb_ker( conja,
			              panel_dim_i,
			              n_p,
			              bmult_m_def,
			              n_p_pad,
			              kappa_cast,
			              a_begin, inca, cbs, cscat,
			              p_begin, bmult_m_pack );
		}
		else
		{
			// Irregular row spacing: fall back to the fully scattered kernel.
			auto a_begin   = a_cast;
			auto rscat_use = rscat + it*bmult_m_def;

			packm_ss_ker( conja,
			              panel_dim_i,
			              n_p,
			              bmult_m_def,
			              n_p_pad,
			              kappa_cast,
			              a_begin, rscat_use, cscat,
			              p_begin, bmult_m_pack );
		}
	}
}
|
||||||
|
|
||||||
|
// scatter_mxn: write an m x n block from a strided matrix x into scattered
// locations of y given by the row/column offset vectors rs_y/cs_y.
// When b == 0, y is overwritten without being read; otherwise y := x + b*y.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
void PASTEMAC(ch,op) \
     ( \
       dim_t m, \
       dim_t n, \
       void* x, inc_t rs_x, inc_t cs_x, \
       void* b, \
       void* y, inc_t* rs_y, inc_t* cs_y \
     ) \
{ \
	ctype* restrict x_cast = (ctype*)x; \
	ctype           b_cast = *(ctype*)b; \
	ctype* restrict y_cast = (ctype*)y; \
\
	/* beta == 0: copy only, so y is never read */ \
	if ( PASTEMAC(ch,eq0)( b_cast ) ) \
	{ \
		for ( auto i = 0; i < m; i++ ) \
		for ( auto j = 0; j < n; j++ ) \
			PASTEMAC(ch,copys)( x_cast[ i*rs_x + j*cs_x ], y_cast[ rs_y[i] + cs_y[j] ] ); \
	} \
	else \
	{ \
		for ( auto i = 0; i < m; i++ ) \
		for ( auto j = 0; j < n; j++ ) \
			PASTEMAC(ch,xpbys)( x_cast[ i*rs_x + j*cs_x ], b_cast, y_cast[ rs_y[i] + cs_y[j] ] ); \
	} \
}

// Instantiate for all datatypes and build a datatype-indexed function array.
INSERT_GENTFUNC_BASIC0(scatter_mxn);

static decltype(&bli_sscatter_mxn) GENARRAY(scatter_mxn, scatter_mxn);
|
||||||
|
|
||||||
|
// Custom gemm macrokernel (installed on matrix C via bli_obj_set_ker_fn())
// which computes C := alpha*A*B + beta*C where C is a general tensor view.
// A and B are already packed into micropanels; C's element locations are
// obtained from scatter/block-scatter vectors computed from the
// gemm_tensor_params_t attached to c. Tiles of C whose rows and columns are
// regularly strided are updated directly by the microkernel; irregular
// tiles are computed into a temporary buffer and scattered out.
void gemm_tensor
     (
       obj_t*     a,
       obj_t*     b,
       obj_t*     c,
       cntx_t*    cntx,
       rntm_t*    rntm,
       cntl_t*    cntl,
       thrinfo_t* thread
     )
{
	auto dt      = bli_obj_dt( c );
	auto dt_size = bli_dt_size( dt );

	auto m = bli_obj_length( c );
	auto n = bli_obj_width( c );
	auto k = bli_obj_width( a );

	auto a_cast = (char*)bli_obj_buffer_at_off( a );
	auto pd_a   = bli_obj_panel_dim( a );
	auto ps_a   = bli_obj_panel_stride( a );

	auto b_cast = (char*)bli_obj_buffer_at_off( b );
	auto pd_b   = bli_obj_panel_dim( b );
	auto ps_b   = bli_obj_panel_stride( b );

	auto c_cast = (char*)bli_obj_buffer_at_off( c );
	auto rs_c0  = bli_obj_row_stride( c );
	auto cs_c0  = bli_obj_col_stride( c );
	auto off_m  = bli_obj_row_off( c );
	auto off_n  = bli_obj_col_off( c );

	auto params   = (gemm_tensor_params_t*)bli_obj_ker_params( c );
	auto ndim_m   = params->ndim_m;
	auto ndim_n   = params->ndim_n;
	auto len_m    = params->len_m;
	auto len_n    = params->len_n;
	auto stride_m = params->stride_m;
	auto stride_n = params->stride_n;

	// If the object's strides no longer line up with the params (e.g. the
	// operation was logically transposed upstream), swap the m and n
	// dimension descriptions to match.
	if ( rs_c0 != stride_m[0] || cs_c0 != stride_n[0] )
	{
		std::swap( ndim_m, ndim_n );
		std::swap( len_m, len_n );
		std::swap( stride_m, stride_n );
	}

	/* If any dimension is zero, return immediately. */
	if ( bli_zero_dim3( m, n, k ) ) return;

	// Back the C pointer up to the tensor origin; the scatter vectors
	// computed below are relative to element (0,...,0).
	c_cast -= ( off_m * stride_m[0] +
	            off_n * stride_n[0] ) * dt_size;

	// Detach and multiply the scalars attached to A and B.
	// NOTE: We know that the internal scalars of A and B are already of the
	// target datatypes because the necessary typecasting would have already
	// taken place during bli_packm_init().
	obj_t scalar_a;
	obj_t scalar_b;
	bli_obj_scalar_detach( a, &scalar_a );
	bli_obj_scalar_detach( b, &scalar_b );
	bli_mulsc( &scalar_a, &scalar_b );

	// Grab the addresses of the internal scalar buffers for the scalar
	// merged above and the scalar attached to C.
	// NOTE: We know that scalar_b is of type dt due to the above code
	// that casts the scalars of A and B to dt via scalar_a and scalar_b,
	// and we know that the internal scalar in C is already of the type dt
	// due to the casting in the implementation of bli_obj_scalar_attach().
	auto alpha_cast = (char*)bli_obj_internal_scalar_buffer( &scalar_b );
	auto beta_cast  = (char*)bli_obj_internal_scalar_buffer( c );

	/* Alias some constants to simpler names. */
	auto MR = pd_a;
	auto NR = pd_b;

	/* Query the context for the micro-kernel address and cast it to its
	   function pointer type. */
	auto gemm_ukr = (gemm_ukr_vft)bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx );

	/* Temporary C buffer for edge cases. Note that the strides of this
	   temporary buffer are set so that they match the storage of the
	   original C matrix. For example, if C is column-stored, ct will be
	   column-stored as well. */
	char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
	auto col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx );
	auto rs_ct = ( col_pref ? 1 : NR );
	auto cs_ct = ( col_pref ? MR : 1 );
	auto zero  = (char*)bli_obj_buffer_for_const( dt, &BLIS_ZERO );

	/*
	   Assumptions/assertions:
	     rs_a == 1
	     cs_a == PACKMR
	     pd_a == MR
	     ps_a == stride to next micro-panel of A
	     rs_b == PACKNR
	     cs_b == 1
	     pd_b == NR
	     ps_b == stride to next micro-panel of B
	     rs_c == (no assumptions)
	     cs_c == (no assumptions)
	*/

	// Allocate the scatter and block-scatter vectors for C from the memory
	// broker (via the control tree node).
	auto scat_size = 2 * (m + n) * sizeof(inc_t);
	auto rscat_c = (inc_t*)bli_packm_alloc_ex( scat_size, BLIS_BUFFER_FOR_GEN_USE, rntm, cntl, thread );
	auto rbs_c   = rscat_c + m;
	auto cscat_c = rbs_c + m;
	auto cbs_c   = cscat_c + n;

	/* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */
	if ( bli_thread_am_ochief( thread ) )
	{
		fill_scatter
		(
		  ndim_m,
		  len_m,
		  stride_m,
		  MR,
		  off_m,
		  m,
		  rscat_c,
		  rbs_c
		);

		fill_scatter
		(
		  ndim_n,
		  len_n,
		  stride_n,
		  NR,
		  off_n,
		  n,
		  cscat_c,
		  cbs_c
		);
	}

	/* Wait for the scatter vectors to be done. */
	bli_thread_barrier( thread );

	/* Compute number of primary and leftover components of the m and n
	   dimensions. */
	auto n_iter = n / NR;
	auto n_left = n % NR;

	auto m_iter = m / MR;
	auto m_left = m % MR;

	if ( n_left ) ++n_iter;
	if ( m_left ) ++m_iter;

	/* Determine some increments used to step through A, B, and C. */
	auto rstep_a = ps_a * dt_size;
	auto cstep_b = ps_b * dt_size;

	/* Save the virtual microkernel address and the params. */
	auxinfo_t aux;
	bli_auxinfo_set_ukr( (void*)gemm_ukr, &aux );
	bli_auxinfo_set_params( params, &aux );

	/* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
	   loop around the microkernel. Here we query the thrinfo_t node for the
	   1st (ir) loop around the microkernel. */
	auto caucus = bli_thrinfo_sub_node( thread );

	/* Query the number of threads and thread ids for each loop. */
	auto jr_nt  = bli_thread_n_way( thread );
	auto jr_tid = bli_thread_work_id( thread );
	auto ir_nt  = bli_thread_n_way( caucus );
	auto ir_tid = bli_thread_work_id( caucus );

	/* Determine the thread range and increment for the 2nd and 1st loops.
	   NOTE: The definition of bli_thread_range_jrir() will depend on whether
	   slab or round-robin partitioning was requested at configure-time. */
	dim_t jr_start, jr_end;
	dim_t ir_start, ir_end;
	dim_t jr_inc, ir_inc;
	bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
	bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );

	/* Loop over the n dimension (NR columns at a time). */
	for ( auto j = jr_start; j < jr_end; j += jr_inc )
	{
		auto b1 = b_cast + j * cstep_b;

		auto n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );

		/* Initialize our next panel of B to be the current panel of B. */
		auto b2 = b1;

		/* Loop over the m dimension (MR rows at a time). */
		for ( auto i = ir_start; i < ir_end; i += ir_inc )
		{
			auto a1       = a_cast + i * rstep_a;
			auto rscat_c1 = rscat_c + i * MR;
			auto rbs_c1   = rbs_c + i * MR;
			auto cscat_c1 = cscat_c + j * NR;
			auto cbs_c1   = cbs_c + j * NR;

			auto m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );

			/* Compute the addresses of the next panels of A and B. */
			auto a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
			if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
			{
				a2 = a_cast;
				b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
				if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
					b2 = b_cast;
			}

			/* Save addresses of next panels of A and B to the auxinfo_t
			   object. */
			bli_auxinfo_set_next_a( a2, &aux );
			bli_auxinfo_set_next_b( b2, &aux );

			// Non-zero block-scatter entries mean this MR x NR tile of C is
			// regularly strided in both dimensions, so the microkernel may
			// update it in place.
			auto rs_c = *rbs_c1;
			auto cs_c = *cbs_c1;

			if ( rs_c && cs_c )
			{
				auto c11 = c_cast + ( *rscat_c1 + *cscat_c1 ) * dt_size;

				/* Invoke the gemm micro-kernel. */
				gemm_ukr
				(
				  m_cur,
				  n_cur,
				  k,
				  alpha_cast,
				  a1,
				  b1,
				  beta_cast,
				  c11, rs_c, cs_c,
				  &aux,
				  cntx
				);
			}
			else
			{
				/* Invoke the gemm micro-kernel, computing a full MR x NR
				   tile into the temporary buffer with beta == 0. */
				gemm_ukr
				(
				  MR,
				  NR,
				  k,
				  alpha_cast,
				  a1,
				  b1,
				  zero,
				  &ct, rs_ct, cs_ct,
				  &aux,
				  cntx
				);

				/* Scatter to C. */
				scatter_mxn[ dt ]
				(
				  m_cur, n_cur,
				  &ct, rs_ct, cs_ct,
				  beta_cast,
				  c_cast, rscat_c1, cscat_c1
				);
			}
		}
	}
}
|
||||||
|
|
||||||
|
static bool has_unit_stride( const std::vector<inc_t>& stride )
|
||||||
|
{
|
||||||
|
for ( auto s : stride )
|
||||||
|
if ( s == 1 )
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the tensor contraction C := alpha * A * B + beta * C, where the
// row (m), column (n), and contracted (k) index groups may each consist of
// several tensor dimensions. Each group's lengths and per-operand strides
// are given in parallel arrays. The groups are folded into matrix
// dimensions and the computation is dispatched to bli_gemm_ex() with
// custom pack/kernel functions which consume the tensor layout.
//
// Throws (via bli_check_error_code) when a stride list's length does not
// match its dimension group.
void tcontract( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                const void* alpha, const void* a, std::vector<inc_t> rs_a, std::vector<inc_t> cs_a,
                const void* b, std::vector<inc_t> rs_b, std::vector<inc_t> cs_b,
                const void* beta, void* c, std::vector<inc_t> rs_c, std::vector<inc_t> cs_c )
{
	if ( rs_a.size() != m.size() ||
	     rs_b.size() != k.size() ||
	     rs_c.size() != m.size() )
		bli_check_error_code( BLIS_INVALID_ROW_STRIDE );

	if ( cs_a.size() != k.size() ||
	     cs_b.size() != n.size() ||
	     cs_c.size() != n.size() )
		bli_check_error_code( BLIS_INVALID_COL_STRIDE );

	// Folded matrix dimensions (products of the group lengths).
	dim_t m_mat = 1;
	dim_t n_mat = 1;
	dim_t k_mat = 1;
	for ( auto& i : m ) m_mat *= i;
	for ( auto& i : n ) n_mat *= i;
	for ( auto& i : k ) k_mat *= i;

	// BUGFIX: when the strides of a dimension group are reordered below,
	// the corresponding lengths must be permuted with them so that
	// len[i] keeps describing the same tensor dimension as stride[i].
	// (Previously the lengths were left unpermuted, which is only correct
	// when all lengths within a group are equal.) Work on local copies
	// since m/n/k are const references.
	auto m_len = m;
	auto n_len = n;
	auto k_len = k;

	// Sort each group by increasing stride (simple bubble sort), keyed on
	// whichever operand's stride list contains a unit stride (favoring C).
	auto& stride_m = has_unit_stride( rs_c ) ? rs_c : rs_a;
	for ( std::size_t i = 1;i < m_len.size(); i++ )
	for ( std::size_t j = 0;j < m_len.size()-i; j++ )
	if ( stride_m[j] > stride_m[j+1] )
	{
		std::swap( rs_a[j], rs_a[j+1] );
		std::swap( rs_c[j], rs_c[j+1] );
		std::swap( m_len[j], m_len[j+1] );
	}

	auto& stride_n = has_unit_stride( cs_c ) ? cs_c : cs_b;
	for ( std::size_t i = 1;i < n_len.size(); i++ )
	for ( std::size_t j = 0;j < n_len.size()-i; j++ )
	if ( stride_n[j] > stride_n[j+1] )
	{
		std::swap( cs_b[j], cs_b[j+1] );
		std::swap( cs_c[j], cs_c[j+1] );
		std::swap( n_len[j], n_len[j+1] );
	}

	auto& stride_k = has_unit_stride( cs_a ) ? cs_a : rs_b;
	for ( std::size_t i = 1;i < k_len.size(); i++ )
	for ( std::size_t j = 0;j < k_len.size()-i; j++ )
	if ( stride_k[j] > stride_k[j+1] )
	{
		std::swap( cs_a[j], cs_a[j+1] );
		std::swap( rs_b[j], rs_b[j+1] );
		std::swap( k_len[j], k_len[j+1] );
	}

	// The matrix objects below need at least one stride entry each.
	if ( rs_a.empty() ) rs_a.push_back( 1 );
	if ( cs_a.empty() ) cs_a.push_back( 1 );
	if ( rs_b.empty() ) rs_b.push_back( 1 );
	if ( cs_b.empty() ) cs_b.push_back( 1 );
	if ( rs_c.empty() ) rs_c.push_back( 1 );
	if ( cs_c.empty() ) cs_c.push_back( 1 );

	obj_t a_o, b_o, c_o;
	bli_obj_create_with_attached_buffer( dt, m_mat, k_mat, const_cast<void*>(a), rs_a[0], cs_a[0], &a_o );
	bli_obj_create_with_attached_buffer( dt, k_mat, n_mat, const_cast<void*>(b), rs_b[0], cs_b[0], &b_o );
	bli_obj_create_with_attached_buffer( dt, m_mat, n_mat,                   c , rs_c[0], cs_c[0], &c_o );

	// The params hold raw pointers into the local vectors, which outlive
	// the bli_gemm_ex() call below.
	packm_tensor_params_t params_a( m_len.size(), m_len.data(), rs_a.data(),
	                                k_len.size(), k_len.data(), cs_a.data() );
	packm_tensor_params_t params_b( n_len.size(), n_len.data(), cs_b.data(),
	                                k_len.size(), k_len.data(), rs_b.data() );
	gemm_tensor_params_t  params_c( m_len.size(), m_len.data(), rs_c.data(),
	                                n_len.size(), n_len.data(), cs_c.data() );

	// Install the custom packing and macrokernel functions and their params.
	bli_obj_set_pack_fn( packm_tensor, &a_o );
	bli_obj_set_pack_fn( packm_tensor, &b_o );
	bli_obj_set_ker_fn( gemm_tensor, &c_o );
	bli_obj_set_pack_params( &params_a, &a_o );
	bli_obj_set_pack_params( &params_b, &b_o );
	bli_obj_set_ker_params( &params_c, &c_o );

	obj_t alpha_o, beta_o;
	bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(alpha), &alpha_o );
	bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(beta), &beta_o );

	// The sup path does not use the custom pack/ker functions, so force the
	// conventional code path.
	rntm_t rntm;
	bli_rntm_init_from_global( &rntm );
	bli_rntm_disable_l3_sup( &rntm );

	bli_gemm_ex( &alpha_o, &a_o, &b_o, &beta_o, &c_o, NULL, &rntm );
}
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
auto N = 5;
|
||||||
|
|
||||||
|
gint_t ndim_a = 4;
|
||||||
|
gint_t ndim_b = 4;
|
||||||
|
gint_t ndim_c = 4;
|
||||||
|
|
||||||
|
std::vector<dim_t> len_a(ndim_a, N);
|
||||||
|
std::vector<dim_t> len_b(ndim_b, N);
|
||||||
|
std::vector<dim_t> len_c(ndim_c, N);
|
||||||
|
|
||||||
|
std::vector<inc_t> stride_a(ndim_a, 1);
|
||||||
|
std::vector<inc_t> stride_b(ndim_b, 1);
|
||||||
|
std::vector<inc_t> stride_c(ndim_c, 1);
|
||||||
|
for ( gint_t i = 1; i < ndim_a; i++ )
|
||||||
|
stride_a[i] = stride_a[i-1] * len_a[i - 1];
|
||||||
|
for ( gint_t i = 1; i < ndim_b; i++ )
|
||||||
|
stride_b[i] = stride_b[i-1] * len_b[i - 1];
|
||||||
|
for ( gint_t i = 1; i < ndim_c; i++ )
|
||||||
|
stride_c[i] = stride_c[i-1] * len_c[i - 1];
|
||||||
|
|
||||||
|
std::vector<int> dim_a(ndim_a);
|
||||||
|
std::vector<int> dim_b(ndim_b);
|
||||||
|
std::vector<int> dim_c(ndim_c);
|
||||||
|
std::iota(dim_a.begin(), dim_a.end(), 0);
|
||||||
|
std::iota(dim_b.begin(), dim_b.end(), 0);
|
||||||
|
std::iota(dim_c.begin(), dim_c.end(), 0);
|
||||||
|
|
||||||
|
for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
|
||||||
|
do
|
||||||
|
do
|
||||||
|
do
|
||||||
|
{
|
||||||
|
auto dt = ( num_t )dt_;
|
||||||
|
|
||||||
|
auto ndim_m = (ndim_a + ndim_c - ndim_b)/2;
|
||||||
|
auto ndim_k = (ndim_a + ndim_b - ndim_c)/2;
|
||||||
|
|
||||||
|
std::vector<dim_t> m(len_a.begin(), len_a.begin()+ndim_m);
|
||||||
|
std::vector<dim_t> n(len_b.begin()+ndim_k, len_b.end());
|
||||||
|
std::vector<dim_t> k(len_b.begin(), len_b.begin()+ndim_k);
|
||||||
|
|
||||||
|
std::vector<inc_t> rs_a(stride_a.begin(), stride_a.begin()+ndim_m);
|
||||||
|
std::vector<inc_t> cs_a(stride_a.begin()+ndim_m, stride_a.end());
|
||||||
|
std::vector<inc_t> rs_b(stride_b.begin(), stride_b.begin()+ndim_k);
|
||||||
|
std::vector<inc_t> cs_b(stride_b.begin()+ndim_k, stride_b.end());
|
||||||
|
std::vector<inc_t> rs_c(stride_c.begin(), stride_c.begin()+ndim_m);
|
||||||
|
std::vector<inc_t> cs_c(stride_c.begin()+ndim_m, stride_c.end());
|
||||||
|
|
||||||
|
dim_t m_tot = 1;
|
||||||
|
dim_t n_tot = 1;
|
||||||
|
dim_t k_tot = 1;
|
||||||
|
for ( auto i : m ) m_tot *= i;
|
||||||
|
for ( auto i : n ) n_tot *= i;
|
||||||
|
for ( auto i : k ) k_tot *= i;
|
||||||
|
|
||||||
|
obj_t a, b, c, c_ref, norm;
|
||||||
|
|
||||||
|
bli_obj_create( dt, m_tot*k_tot, 1, 1, 1, &a );
|
||||||
|
bli_obj_create( dt, k_tot*n_tot, 1, 1, 1, &b );
|
||||||
|
bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c );
|
||||||
|
bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c_ref );
|
||||||
|
bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
|
||||||
|
|
||||||
|
bli_randv( &a );
|
||||||
|
bli_randv( &b );
|
||||||
|
bli_randv( &c );
|
||||||
|
bli_copyv( &c, &c_ref );
|
||||||
|
|
||||||
|
tcontract( dt, m, n, k,
|
||||||
|
bli_obj_buffer_for_const( dt, &BLIS_ONE ),
|
||||||
|
bli_obj_buffer( &a ), rs_a, cs_a,
|
||||||
|
bli_obj_buffer( &b ), rs_b, cs_b,
|
||||||
|
bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
|
||||||
|
bli_obj_buffer( &c ), rs_c, cs_c );
|
||||||
|
|
||||||
|
tcontract_ref( dt, m, n, k,
|
||||||
|
bli_obj_buffer_for_const( dt, &BLIS_ONE ),
|
||||||
|
bli_obj_buffer( &a ), rs_a, cs_a,
|
||||||
|
bli_obj_buffer( &b ), rs_b, cs_b,
|
||||||
|
bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
|
||||||
|
bli_obj_buffer( &c_ref ), rs_c, cs_c );
|
||||||
|
|
||||||
|
bli_subv( &c_ref, &c );
|
||||||
|
bli_normfv( &c, &norm );
|
||||||
|
|
||||||
|
double normr, normi;
|
||||||
|
bli_getsc( &norm, &normr, &normi );
|
||||||
|
|
||||||
|
printf("dt: %d, dim_a: [%d,%d,%d,%d], dim_b: [%d,%d,%d,%d], dim_c: [%d,%d,%d,%d], norm: %g\n",
|
||||||
|
dt, dim_a[0], dim_a[1], dim_a[2], dim_a[3],
|
||||||
|
dim_b[0], dim_b[1], dim_b[2], dim_b[3],
|
||||||
|
dim_c[0], dim_c[1], dim_c[2], dim_c[3],
|
||||||
|
normr / std::sqrt( bli_obj_vector_dim( &c ) ) );
|
||||||
|
|
||||||
|
bli_obj_free( &a );
|
||||||
|
bli_obj_free( &b );
|
||||||
|
bli_obj_free( &c );
|
||||||
|
bli_obj_free( &c_ref );
|
||||||
|
}
|
||||||
|
while (std::next_permutation(dim_a.begin(), dim_a.end()));
|
||||||
|
while (std::next_permutation(dim_b.begin(), dim_b.end()));
|
||||||
|
while (std::next_permutation(dim_c.begin(), dim_c.end()));
|
||||||
|
}
|
||||||
|
|
||||||
67
test/tensor_contraction/tcontract_ref.cxx
Normal file
67
test/tensor_contraction/tcontract_ref.cxx
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
#include "tcontract_ref.hpp"
|
||||||
|
|
||||||
|
// Reference implementation of the tensor contraction
// C := alpha*A*B + beta*C over the index groups m (rows), n (columns), and
// k (contracted), each of which may span several tensor dimensions. Each
// group's lengths and per-operand strides are passed as parallel arrays;
// elements are of type T.
template <typename T>
void tcontract_ref( const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
                    const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
                    const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
{
	auto alpha_cast = *( T* )alpha;
	auto beta_cast = *( T* )beta;
	auto a_cast = ( T* )a;
	auto b_cast = ( T* )b;
	auto c_cast = ( T* )c;

	// for_each advances the captured pointers in lock-step over the index
	// space and leaves them back at their starting values once the full
	// space has been traversed (checked by the asserts below).
	for_each(m.size(), m.data(), a_cast, rs_a.data(), c_cast, rs_c.data(),
	[&]
	{
		for_each(n.size(), n.data(), b_cast, cs_b.data(), c_cast, cs_c.data(),
		[&]
		{
			// Inner product over all contracted indices.
			auto ab = convert<T>(0.0);

			for_each(k.size(), k.data(), a_cast, cs_a.data(), b_cast, rs_b.data(),
			[&]
			{
				ab += (*a_cast) * (*b_cast);
			});

			// When beta is exactly zero, overwrite C instead of scaling it
			// so that existing NaN/Inf values in C are not propagated.
			if ( beta_cast == convert<T>(0.0) )
			{
				*c_cast = alpha_cast * ab;
			}
			else
			{
				*c_cast = alpha_cast * ab + beta_cast * (*c_cast);
			}
		});

		// The n/k traversals must restore b's cursor each time around.
		assert(b_cast == b);
	});

	assert(a_cast == a);
	assert(c_cast == c);
}
|
||||||
|
|
||||||
|
// Instantiate tcontract_ref<T> for each BLIS datatype and collect the
// instances into a datatype-indexed function array used by the dispatcher.
#undef GENTFUNC
#define GENTFUNC(ctype,ch,op) \
static auto PASTEMAC(ch,op) = &tcontract_ref<ctype>;

INSERT_GENTFUNC_BASIC0(tcontract_ref);

static decltype(&tcontract_ref<void>) GENARRAY( tcontract_ref_impl, tcontract_ref );
|
||||||
|
|
||||||
|
// Runtime-datatype entry point for the reference contraction: look up the
// implementation instantiated for dt and forward all arguments to it.
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
                    const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
                    const void* beta, void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
{
	const auto impl = tcontract_ref_impl[ dt ];

	impl( m, n, k,
	      alpha, a, rs_a, cs_a,
	      b, rs_b, cs_b,
	      beta, c, rs_c, cs_c );
}
|
||||||
|
|
||||||
100
test/tensor_contraction/tcontract_ref.hpp
Normal file
100
test/tensor_contraction/tcontract_ref.hpp
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
#include "blis.h"
#include "complex_math.hpp"

#include <array>
#include <cassert>
#include <functional>
#include <numeric>
#include <vector>
|
||||||
|
|
||||||
|
// Base case of the variadic pointer-advancing helper below: no more
// (pointer, stride-array) pairs remain in the pack, so do nothing.
inline void increment(inc_t, gint_t) {}
|
||||||
|
|
||||||
|
template <typename T, typename... Args>
|
||||||
|
void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args)
|
||||||
|
{
|
||||||
|
off += s[i]*n;
|
||||||
|
increment(n, i, args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Visit `len` consecutive positions of an ndim-dimensional index space
// (dimension sizes in n[0..ndim-1], dimension 0 varying fastest), starting
// at linear offset `off`.  `body()` is invoked once per position; the
// trailing (pointer, stride-array) pairs in `args` are kept in sync with
// the current multi-index by odometer-style carry propagation, so each
// pointer always addresses the element at the current index.
template <typename Body, typename... Args>
void for_each_impl(gint_t ndim, const dim_t* n,
                   dim_t off, dim_t len,
                   Body& body,
                   Args&... args)
{
    // Current multi-index; at most 8 dimensions are supported.
    std::array<dim_t,8> i = {};
    assert( ndim <= i.size() );

    // Decompose the linear offset into a mixed-radix multi-index and
    // advance every tracked pointer to the starting element.
    if ( off )
    {
        for ( gint_t k = 0; k < ndim; k++ )
        {
            i[k] = off % n[k];
            off /= n[k];
            increment(i[k], k, args...);
        }
    }

    for ( dim_t pos = 0; pos < len; pos++ )
    {
        body();

        // Odometer increment: bump the fastest dimension; on wrap-around,
        // rewind that dimension (and its pointers) to 0 and carry into the
        // next-slower dimension.
        for ( gint_t k = 0; k < ndim; k++ )
        {
            if ( i[k] == n[k]-1 )
            {
                increment(-i[k], k, args...);
                i[k] = 0;
            }
            else
            {
                increment(1, k, args...);
                i[k]++;
                break;
            }
        }
    }
}
|
||||||
|
|
||||||
|
template <typename T, typename Body>
|
||||||
|
void for_each(gint_t ndim, const dim_t* n,
|
||||||
|
dim_t off, dim_t len,
|
||||||
|
T& a, const inc_t* s_a,
|
||||||
|
Body&& body)
|
||||||
|
{
|
||||||
|
for_each_impl( ndim, n, off, len, body, a, s_a );
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Body>
|
||||||
|
void for_each(gint_t ndim, const dim_t* n,
|
||||||
|
dim_t off, dim_t len,
|
||||||
|
T& a, const inc_t* s_a,
|
||||||
|
T& b, const inc_t* s_b,
|
||||||
|
Body&& body)
|
||||||
|
{
|
||||||
|
for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b );
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Body>
|
||||||
|
void for_each(gint_t ndim, const dim_t* n,
|
||||||
|
T& a, const inc_t* s_a,
|
||||||
|
Body&& body)
|
||||||
|
{
|
||||||
|
dim_t len = 1;
|
||||||
|
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
|
||||||
|
for_each_impl( ndim, n, 0, len, body, a, s_a );
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename Body>
|
||||||
|
void for_each(gint_t ndim, const dim_t* n,
|
||||||
|
T& a, const inc_t* s_a,
|
||||||
|
T& b, const inc_t* s_b,
|
||||||
|
Body&& body)
|
||||||
|
{
|
||||||
|
dim_t len = 1;
|
||||||
|
for ( gint_t i = 0;i < ndim;i++ ) len *= n[i];
|
||||||
|
for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b );
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reference tensor contraction, dispatched on the runtime datatype dt:
//   C := alpha * (A contracted with B) + beta * C
// m, n, and k hold the sizes of the free dimensions of A/C, the free
// dimensions of B/C, and the contracted (bound) dimensions shared by A
// and B, respectively; the rs_*/cs_* vectors hold the matching strides.
// When beta is zero, C is overwritten without being read.
void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
                                       const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
                    const void* beta,        void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c );
|
||||||
Reference in New Issue
Block a user