diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c
index ab58842c7..cce4770b3 100644
--- a/frame/compat/bla_gemm.c
+++ b/frame/compat/bla_gemm.c
@@ -290,9 +290,135 @@ void PASTEF77(ch,blasname) \
 	bli_finalize_auto(); \
 }
 
+/*
+ * BLAS-compatible double-complex GEMM entry point, written out by hand
+ * rather than generated by INSERT_GENTFUNC_BLAS so that the execution
+ * path can be chosen from the problem size.
+ *
+ * Arguments are passed by pointer. A, B and C are wrapped as objects with
+ * row stride 1 and column stride equal to their leading dimension.
+ *
+ * Dispatch:
+ *   - m <= 128, 68 < n <= 128 and k <= 128: bli_gemmind().
+ *   - otherwise: bli_gemmsup(), then bli_gemmnat() if sup does not
+ *     return BLIS_SUCCESS.
+ *
+ * Every path falls through to the trace-exit and bli_finalize_auto()
+ * calls at the end of the function.
+ */
+void zgemm_
+     (
+       const f77_char* transa,
+       const f77_char* transb,
+       const f77_int*  m,
+       const f77_int*  n,
+       const f77_int*  k,
+       const dcomplex* alpha,
+       const dcomplex* a, const f77_int* lda,
+       const dcomplex* b, const f77_int* ldb,
+       const dcomplex* beta,
+             dcomplex* c, const f77_int* ldc
+     )
+{
+	/* Tag the log record with 'z'; this routine is not expanded from a
+	   macro, so there is no `ch` parameter to stringize. */
+	AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb,
+	                         *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc);
+
+	trans_t blis_transa;
+	trans_t blis_transb;
+	dim_t   m0, n0, k0;
+
+	AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
+
+	/* Initialize BLIS. */
+	bli_init_auto();
+
+	/* Perform BLAS parameter checking. */
+	PASTEBLACHK(gemm)
+	(
+	  MKSTR(z),
+	  MKSTR(gemm),
+	  transa,
+	  transb,
+	  m,
+	  n,
+	  k,
+	  lda,
+	  ldb,
+	  ldc
+	);
+
+	/* Map BLAS chars to their corresponding BLIS enumerated type value. */
+	bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
+	bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
+
+	/* Typecast BLAS integers to BLIS integers. */
+	bli_convert_blas_dim1( *m, m0 );
+	bli_convert_blas_dim1( *n, n0 );
+	bli_convert_blas_dim1( *k, k0 );
+
+	/* Set the row and column strides of the matrix operands. */
+	const inc_t rs_a = 1;
+	const inc_t cs_a = *lda;
+	const inc_t rs_b = 1;
+	const inc_t cs_b = *ldb;
+	const inc_t rs_c = 1;
+	const inc_t cs_c = *ldc;
+
+	const num_t dt = BLIS_DCOMPLEX;
+
+	obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
+	obj_t ao     = BLIS_OBJECT_INITIALIZER;
+	obj_t bo     = BLIS_OBJECT_INITIALIZER;
+	obj_t betao  = BLIS_OBJECT_INITIALIZER_1X1;
+	obj_t co     = BLIS_OBJECT_INITIALIZER;
+
+	dim_t m0_a, n0_a;
+	dim_t m0_b, n0_b;
+
+	/* Derive the object dimensions of A and B from their trans flags. */
+	bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
+	bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
+
+	bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao );
+	bli_obj_init_finish_1x1( dt, (dcomplex*)beta,  &betao  );
+
+	bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao );
+	bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo );
+	bli_obj_init_finish( dt, m0,   n0,   (dcomplex*)c, rs_c, cs_c, &co );
+
+	bli_obj_set_conjtrans( blis_transa, &ao );
+	bli_obj_set_conjtrans( blis_transb, &bo );
+
+	if ( ( m0 <= 128 ) && ( n0 > 68 ) && ( n0 <= 128 ) && ( k0 <= 128 ) )
+	{
+		// Induced 3m1 performs better for the above case.
+		bli_gemmind( &alphao, &ao, &bo, &betao, &co, NULL, NULL );
+	}
+	else
+	{
+		err_t status = bli_gemmsup( &alphao, &ao, &bo, &betao, &co, NULL, NULL );
+
+		// Fall back on the native path when zgemm is not handled in sup path.
+		if ( status != BLIS_SUCCESS )
+		{
+			bli_gemmnat( &alphao, &ao, &bo, &betao, &co, NULL, NULL );
+		}
+	}
+
+	/* No early returns above: every path must reach the trace exit and
+	   finalization below. */
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO)
+
+	/* Finalize BLIS. */
+	bli_finalize_auto();
+}
 #endif
 
 #ifdef BLIS_ENABLE_BLAS
-INSERT_GENTFUNC_BLAS( gemm, gemm )
-#endif
-
+// zgemm_ is written out by hand above, so only s, d and c are generated here.
+// NOTE(review): zgemm_ sits inside the preceding conditional block; confirm
+// that block is compiled in every configuration that reaches this line.
+INSERT_GENTFUNC_BLAS_SDC( gemm, gemm )
+#endif
diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h
index cdea0c145..7c0ca3c87 100644
--- a/frame/include/bli_gentfunc_macro_defs.h
+++ b/frame/include/bli_gentfunc_macro_defs.h
@@ -56,6 +56,12 @@ GENTFUNC( double, d, blasname, blisname ) \
 GENTFUNC( scomplex, c, blasname, blisname ) \
 GENTFUNC( dcomplex, z, blasname, blisname )
 
+// Same as INSERT_GENTFUNC_BLAS, minus the dcomplex (z) instantiation.
+#define INSERT_GENTFUNC_BLAS_SDC( blasname, blisname ) \
+\
+GENTFUNC( float, s, blasname, blisname ) \
+GENTFUNC( double, d, blasname, blisname ) \
+GENTFUNC( scomplex, c, blasname, blisname )
 
 #define INSERT_GENTFUNC_BLAS_CZ( blasname, blisname ) \
 \
diff --git a/frame/ind/bli_l3_ind.c b/frame/ind/bli_l3_ind.c
index 6f8467c6b..c703be292 100644
--- a/frame/ind/bli_l3_ind.c
+++ b/frame/ind/bli_l3_ind.c
@@ -68,7 +68,7 @@ bool bli_l3_ind_oper_st[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS][2] =
 /* c z */
 /* 3mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
             {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
-/* 3m1 */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
+/* 3m1 */ { {FALSE,TRUE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
             {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },
 /* 4mh */ { {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE},
             {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE}, {FALSE,FALSE} },