diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index cce4770b3..e455c095f 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -56,7 +56,7 @@ void PASTEF77(ch,blasname) \ const ftype* a, const f77_int* lda, \ const ftype* b, const f77_int* ldb, \ const ftype* beta, \ - ftype* c, const f77_int* ldc \ + ftype* c, const f77_int* ldc \ ) \ { \ trans_t blis_transa; \ @@ -140,7 +140,7 @@ void PASTEF77(ch,blasname) \ const ftype* a, const f77_int* lda, \ const ftype* b, const f77_int* ldb, \ const ftype* beta, \ - ftype* c, const f77_int* ldc \ + ftype* c, const f77_int* ldc \ ) \ { \ AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ @@ -290,111 +290,310 @@ void PASTEF77(ch,blasname) \ bli_finalize_auto(); \ } -void zgemm_ - ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* b, const f77_int* ldb, - const dcomplex* beta, - dcomplex* c, const f77_int* ldc - ) +void dgemm_ +( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc +) { - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - /* Initialize BLIS. */ - bli_init_auto(); + /* Initialize BLIS. */ + bli_init_auto(); - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(d), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + bli_convert_blas_dim1(*k, k0); - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - const num_t dt = BLIS_DCOMPLEX; + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - if ((m0 <=128) && (n0 > 68) && (n0 <= 128) && (k0 <= 128)) + if (n0 == 1) + { + if (bli_is_notrans(blis_transa)) { - // induced 3m1 performs better for above case. - bli_gemmind(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - return; + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); } else { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - return; - } - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - return; + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); } + return; + } + else if (m0 == 1) + { + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + return; + } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double*)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + //cntx_t* cntx = bli_gks_query_cntx(); + + //if ( (m0 == 128) && (n0 > 2) ) + + if (bli_is_notrans(blis_transa)) + { + if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) + { + err_t status = bli_dgemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } + } + } + + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + return; + } + // fall back on native path when dgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ + /* ( */ + /* &alphao, */ + /* &ao, */ + /* &bo, */ + /* &betao, */ + /* &co, */ + /* NULL, */ + /* NULL */ + /* ); */ + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); /* Finalize BLIS. */ bli_finalize_auto(); -} +} // end of dgemm_ + +void zgemm_ +( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc +) +{ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + bli_convert_blas_dim1(*k, k0); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (dcomplex*)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (dcomplex*)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + if ((m0 <= 128) && (n0 > 68) && (n0 <= 128) && (k0 <= 128)) + { + // induced 3m1 performs better for above case. + bli_gemmind(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + return; + } + else + { + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + return; + } + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + return; + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of zgemm_ + + #endif #ifdef BLIS_ENABLE_BLAS -//INSERT_GENTFUNC_BLAS( gemm, gemm ) -INSERT_GENTFUNC_BLAS_SDC( gemm, gemm ) -#endif \ No newline at end of file +INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) +#endif diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 7c0ca3c87..ee05dbcc9 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -56,12 +56,12 @@ GENTFUNC( double, d, blasname, blisname ) \ GENTFUNC( scomplex, c, blasname, blisname ) \ GENTFUNC( dcomplex, z, blasname, blisname ) -#define INSERT_GENTFUNC_BLAS_SDC( blasname, blisname ) \ +#define INSERT_GENTFUNC_BLAS_SC( blasname, blisname ) \ \ GENTFUNC( float, s, blasname, blisname ) \ -GENTFUNC( double, d, blasname, blisname ) \ GENTFUNC( scomplex, c, blasname, blisname ) + #define INSERT_GENTFUNC_BLAS_CZ( blasname, blisname ) \ \ GENTFUNC( scomplex, c, blasname, blisname ) \ diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 487d950c4..f7794e70f 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -60,16 +60,16 @@ static err_t bli_sgemm_small cntl_t* cntl ); -static err_t bli_dgemm_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); +/* static err_t bli_dgemm_small */ +/* ( */ +/* obj_t* alpha, */ +/* obj_t* a, */ +/* obj_t* b, */ +/* obj_t* beta, */ +/* obj_t* c, */ +/* cntx_t* cntx, */ +/* cntl_t* cntl */ +/* ); */ static err_t bli_sgemm_small_atbn ( @@ -1713,7 +1713,7 @@ static err_t bli_sgemm_small }; -static err_t bli_dgemm_small +/*static*/ err_t bli_dgemm_small ( obj_t* alpha, obj_t* a, @@ -1725,43 +1725,21 @@ static err_t bli_dgemm_small ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - gint_t M = bli_obj_length( c ); // number of rows of Matrix C - gint_t N = bli_obj_width( c ); // number of columns of Matrix C - gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . - gint_t L = M * N; + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . + gint_t L = M * N; - // when N is equal to 1 call GEMV instead of GEMM - if (N == 1) - { - bli_gemv - ( - alpha, - a, - b, - beta, - c - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return BLIS_SUCCESS; - } - if (N<3) //Implemenation assumes that N is atleast 3. - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "N < 3, cannot be processed by small gemm" - ); - return BLIS_NOT_YET_IMPLEMENTED; - } - -#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME - if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME)))) -#else - if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) - || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) -#endif +/* #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME */ +/* if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME)))) */ +/* #else */ +/* if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) */ +/* || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) */ +/* #endif */ + if(L && K ) // Non-zero dimensions will be handled by either sup or native kernels { guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index a16914352..0816440c5 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -213,3 +213,14 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 ) GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 ) + + err_t bli_dgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + );