Adpative zgemm

1. 3m1 choosen for (m<=128) &  (68>n<=128) & (k<=128)
2. Default blis3.1 path for rest of the sizes.

Change-Id: I1e50dece013e72a67f1162faef5cbeb9bfbbc23a
AMD-Internal: [CPUPL-1352]
This commit is contained in:
Madan mohan Manokar
2021-01-22 12:38:08 +05:30
parent 2e7cf8d82f
commit f1ea1f1d34
3 changed files with 111 additions and 4 deletions

View File

@@ -290,9 +290,111 @@ void PASTEF77(ch,blasname) \
bli_finalize_auto(); \
}
void zgemm_
(
const f77_char* transa,
const f77_char* transb,
const f77_int* m,
const f77_int* n,
const f77_int* k,
const dcomplex* alpha,
const dcomplex* a, const f77_int* lda,
const dcomplex* b, const f77_int* ldb,
const dcomplex* beta,
dcomplex* c, const f77_int* ldc
)
{
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc);
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
/* Initialize BLIS. */
bli_init_auto();
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm)
(
MKSTR(z),
MKSTR(gemm),
transa,
transb,
m,
n,
k,
lda,
ldb,
ldc
);
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
/* Typecast BLAS integers to BLIS integers. */
bli_convert_blas_dim1( *m, m0 );
bli_convert_blas_dim1( *n, n0 );
bli_convert_blas_dim1( *k, k0 );
/* Set the row and column strides of the matrix operands. */
const inc_t rs_a = 1;
const inc_t cs_a = *lda;
const inc_t rs_b = 1;
const inc_t cs_b = *ldb;
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
const num_t dt = BLIS_DCOMPLEX;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t co = BLIS_OBJECT_INITIALIZER;
dim_t m0_a, n0_a;
dim_t m0_b, n0_b;
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao );
bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao );
bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao );
bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo );
bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co );
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_conjtrans( blis_transb, &bo );
if ((m0 <=128) && (n0 > 68) && (n0 <= 128) && (k0 <= 128))
{
// induced 3m1 performs better for above case.
bli_gemmind(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
return;
}
else
{
err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
if(status==BLIS_SUCCESS)
{
return;
}
// fall back on native path when zgemm is not handled in sup path.
bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
return;
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO)
/* Finalize BLIS. */
bli_finalize_auto();
}
#endif
#ifdef BLIS_ENABLE_BLAS
INSERT_GENTFUNC_BLAS( gemm, gemm )
#endif
//INSERT_GENTFUNC_BLAS( gemm, gemm )
INSERT_GENTFUNC_BLAS_SDC( gemm, gemm )
#endif

View File

@@ -56,6 +56,11 @@ GENTFUNC( double, d, blasname, blisname ) \
GENTFUNC( scomplex, c, blasname, blisname ) \
GENTFUNC( dcomplex, z, blasname, blisname )
#define INSERT_GENTFUNC_BLAS_SDC( blasname, blisname ) \
\
GENTFUNC( float, s, blasname, blisname ) \
GENTFUNC( double, d, blasname, blisname ) \
GENTFUNC( scomplex, c, blasname, blisname )
#define INSERT_GENTFUNC_BLAS_CZ( blasname, blisname ) \
\

View File

@@ -68,7 +68,7 @@ bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] =
/* c z */
/* 3mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
/* 3m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
/* 3m1 */ { {FALSE,TRUE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
/* 4mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
{FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },