DGEMM Optimizations for smaller dimensions

Modified dgemm_ to be able to call the small_gemm 16x3 kernel.
small_gemm is called when A is not transposed and ((m + n - k) < 2000) && ((m + k - n) < 2000) && ((n + k - m) < 2000) && (n > 2).
small_gemm kernel: if m, n, or k is 0 it returns without computing, and that case is then handled by the sup or native kernel.

[CPUPL - 1376]

Change-Id: I61c2b36ad0ae4fb3dd23bc37c2b6c78556b3105b
This commit is contained in:
Kiran Varaganti
2021-02-11 11:05:42 +05:30
parent 3ab9104dae
commit a7d43cf720
4 changed files with 324 additions and 136 deletions

View File

@@ -56,7 +56,7 @@ void PASTEF77(ch,blasname) \
const ftype* a, const f77_int* lda, \
const ftype* b, const f77_int* ldb, \
const ftype* beta, \
ftype* c, const f77_int* ldc \
ftype* c, const f77_int* ldc \
) \
{ \
trans_t blis_transa; \
@@ -140,7 +140,7 @@ void PASTEF77(ch,blasname) \
const ftype* a, const f77_int* lda, \
const ftype* b, const f77_int* ldb, \
const ftype* beta, \
ftype* c, const f77_int* ldc \
ftype* c, const f77_int* ldc \
) \
{ \
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \
@@ -290,111 +290,310 @@ void PASTEF77(ch,blasname) \
bli_finalize_auto(); \
}
void zgemm_
(
const f77_char* transa,
const f77_char* transb,
const f77_int* m,
const f77_int* n,
const f77_int* k,
const dcomplex* alpha,
const dcomplex* a, const f77_int* lda,
const dcomplex* b, const f77_int* ldb,
const dcomplex* beta,
dcomplex* c, const f77_int* ldc
)
void dgemm_
(
const f77_char* transa,
const f77_char* transb,
const f77_int* m,
const f77_int* n,
const f77_int* k,
const double* alpha,
const double* a, const f77_int* lda,
const double* b, const f77_int* ldb,
const double* beta,
double* c, const f77_int* ldc
)
{
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc);
AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc);
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO)
/* Initialize BLIS. */
bli_init_auto();
/* Initialize BLIS. */
bli_init_auto();
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm)
(
MKSTR(z),
MKSTR(gemm),
transa,
transb,
m,
n,
k,
lda,
ldb,
ldc
);
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm)
(
MKSTR(d),
MKSTR(gemm),
transa,
transb,
m,
n,
k,
lda,
ldb,
ldc
);
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
bli_param_map_netlib_to_blis_trans( *transa, &blis_transa );
bli_param_map_netlib_to_blis_trans( *transb, &blis_transb );
bli_param_map_netlib_to_blis_trans(*transa, &blis_transa);
bli_param_map_netlib_to_blis_trans(*transb, &blis_transb);
/* Typecast BLAS integers to BLIS integers. */
bli_convert_blas_dim1( *m, m0 );
bli_convert_blas_dim1( *n, n0 );
bli_convert_blas_dim1( *k, k0 );
/* Typecast BLAS integers to BLIS integers. */
bli_convert_blas_dim1(*m, m0);
bli_convert_blas_dim1(*n, n0);
bli_convert_blas_dim1(*k, k0);
/* Set the row and column strides of the matrix operands. */
const inc_t rs_a = 1;
const inc_t cs_a = *lda;
const inc_t rs_b = 1;
const inc_t cs_b = *ldb;
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
const num_t dt = BLIS_DCOMPLEX;
/* Set the row and column strides of the matrix operands. */
const inc_t rs_a = 1;
const inc_t cs_a = *lda;
const inc_t rs_b = 1;
const inc_t cs_b = *ldb;
const inc_t rs_c = 1;
const inc_t cs_c = *ldc;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t co = BLIS_OBJECT_INITIALIZER;
dim_t m0_a, n0_a;
dim_t m0_b, n0_b;
bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a );
bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b );
bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao );
bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao );
bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao );
bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo );
bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co );
bli_obj_set_conjtrans( blis_transa, &ao );
bli_obj_set_conjtrans( blis_transb, &bo );
if ((m0 <=128) && (n0 > 68) && (n0 <= 128) && (k0 <= 128))
if (n0 == 1)
{
if (bli_is_notrans(blis_transa))
{
// induced 3m1 performs better for above case.
bli_gemmind(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
return;
bli_dgemv_unf_var2(
BLIS_NO_TRANSPOSE,
bli_extract_conj(blis_transb),
m0, k0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
(double*)beta,
c, rs_c,
((void*)0)
);
}
else
{
err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
if(status==BLIS_SUCCESS)
{
return;
}
// fall back on native path when zgemm is not handled in sup path.
bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
return;
bli_dgemv_unf_var1(
blis_transa,
bli_extract_conj(blis_transb),
k0, m0,
(double*)alpha,
(double*)a, rs_a, cs_a,
(double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b,
(double*)beta,
c, rs_c,
((void*)0)
);
}
return;
}
else if (m0 == 1)
{
if (bli_is_notrans(blis_transb))
{
bli_dgemv_unf_var1(
blis_transb,
bli_extract_conj(blis_transa),
n0, k0,
(double*)alpha,
(double*)b, cs_b, rs_b,
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
(double*)beta,
c, cs_c,
((void*)0)
);
}
else
{
bli_dgemv_unf_var2(
blis_transb,
bli_extract_conj(blis_transa),
k0, n0,
(double*)alpha,
(double*)b, cs_b, rs_b,
(double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a,
(double*)beta,
c, cs_c,
((void*)0)
);
}
return;
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO)
const num_t dt = BLIS_DOUBLE;
obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t ao = BLIS_OBJECT_INITIALIZER;
obj_t bo = BLIS_OBJECT_INITIALIZER;
obj_t betao = BLIS_OBJECT_INITIALIZER_1X1;
obj_t co = BLIS_OBJECT_INITIALIZER;
dim_t m0_a, n0_a;
dim_t m0_b, n0_b;
bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a);
bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b);
bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao);
bli_obj_init_finish_1x1(dt, (double*)beta, &betao);
bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao);
bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo);
bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co);
bli_obj_set_conjtrans(blis_transa, &ao);
bli_obj_set_conjtrans(blis_transb, &bo);
//cntx_t* cntx = bli_gks_query_cntx();
//if ( (m0 == 128) && (n0 > 2) )
if (bli_is_notrans(blis_transa))
{
if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2))
{
err_t status = bli_dgemm_small( &alphao,
&ao,
&bo,
&betao,
&co,
NULL, //cntx,
NULL
);
if (status == BLIS_SUCCESS)
{
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
}
}
err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
if (status == BLIS_SUCCESS)
{
return;
}
// fall back on native path when dgemm is not handled in sup path.
bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
/* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */
/* ( */
/* &alphao, */
/* &ao, */
/* &bo, */
/* &betao, */
/* &co, */
/* NULL, */
/* NULL */
/* ); */
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
}
} // end of dgemm_
/*
 * zgemm_ - BLAS-compatible (Fortran-77 style, all arguments by pointer)
 * double-complex general matrix multiply.
 *
 * transa, transb : BLAS transpose characters for A and B.
 * m, n, k        : problem dimensions (converted to BLIS dim_t below).
 * alpha, beta    : pointers to the dcomplex scalars.
 * a, b, c        : matrix buffers; they are wrapped with row stride 1 and
 *                  column stride lda/ldb/ldc respectively.
 *
 * Dispatch:
 *   - m <= 128, 68 < n <= 128, k <= 128 : bli_gemmind.
 *   - otherwise                          : bli_gemmsup, falling back to
 *                                          bli_gemmnat when sup does not
 *                                          report BLIS_SUCCESS.
 *
 * All paths leave through a single exit so that the trace-exit macro and
 * bli_finalize_auto() always run. Previously every branch returned early
 * and that epilogue was unreachable.
 */
void zgemm_
(
    const f77_char* transa,
    const f77_char* transb,
    const f77_int*  m,
    const f77_int*  n,
    const f77_int*  k,
    const dcomplex* alpha,
    const dcomplex* a, const f77_int* lda,
    const dcomplex* b, const f77_int* ldb,
    const dcomplex* beta,
    dcomplex*       c, const f77_int* ldc
)
{
    /* Log with the 'z' type letter: `ch` is only a parameter of the generic
       macro template, and MKSTR(z) is what the parameter check below uses. */
    AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc);

    trans_t blis_transa;
    trans_t blis_transb;
    dim_t   m0, n0, k0;

    AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);

    /* Initialize BLIS. */
    bli_init_auto();

    /* Perform BLAS parameter checking. */
    PASTEBLACHK(gemm)
    (
        MKSTR(z),
        MKSTR(gemm),
        transa,
        transb,
        m,
        n,
        k,
        lda,
        ldb,
        ldc
    );

    /* Map BLAS chars to their corresponding BLIS enumerated type value. */
    bli_param_map_netlib_to_blis_trans(*transa, &blis_transa);
    bli_param_map_netlib_to_blis_trans(*transb, &blis_transb);

    /* Typecast BLAS integers to BLIS integers. */
    bli_convert_blas_dim1(*m, m0);
    bli_convert_blas_dim1(*n, n0);
    bli_convert_blas_dim1(*k, k0);

    /* Set the row and column strides of the matrix operands. */
    const inc_t rs_a = 1;
    const inc_t cs_a = *lda;
    const inc_t rs_b = 1;
    const inc_t cs_b = *ldb;
    const inc_t rs_c = 1;
    const inc_t cs_c = *ldc;

    const num_t dt = BLIS_DCOMPLEX;

    obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1;
    obj_t ao     = BLIS_OBJECT_INITIALIZER;
    obj_t bo     = BLIS_OBJECT_INITIALIZER;
    obj_t betao  = BLIS_OBJECT_INITIALIZER_1X1;
    obj_t co     = BLIS_OBJECT_INITIALIZER;

    /* Stored dimensions of A and B, derived from (m, k) and (k, n) and the
       respective transpose settings. */
    dim_t m0_a, n0_a;
    dim_t m0_b, n0_b;

    bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a);
    bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b);

    /* Wrap the raw scalars and buffers in BLIS objects. */
    bli_obj_init_finish_1x1(dt, (dcomplex*)alpha, &alphao);
    bli_obj_init_finish_1x1(dt, (dcomplex*)beta, &betao);

    bli_obj_init_finish(dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao);
    bli_obj_init_finish(dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo);
    bli_obj_init_finish(dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co);

    bli_obj_set_conjtrans(blis_transa, &ao);
    bli_obj_set_conjtrans(blis_transb, &bo);

    if ((m0 <= 128) && (n0 > 68) && (n0 <= 128) && (k0 <= 128))
    {
        // induced 3m1 performs better for above case.
        bli_gemmind(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
    }
    else
    {
        err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
        if (status != BLIS_SUCCESS)
        {
            // fall back on native path when zgemm is not handled in sup path.
            bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL);
        }
    }

    /* Common epilogue, reached from every dispatch path. */
    AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
    /* Finalize BLIS. */
    bli_finalize_auto();
}// end of zgemm_
#endif
#ifdef BLIS_ENABLE_BLAS
//INSERT_GENTFUNC_BLAS( gemm, gemm )
INSERT_GENTFUNC_BLAS_SDC( gemm, gemm )
#endif
INSERT_GENTFUNC_BLAS_SC( gemm, gemm )
#endif

View File

@@ -56,12 +56,12 @@ GENTFUNC( double, d, blasname, blisname ) \
GENTFUNC( scomplex, c, blasname, blisname ) \
GENTFUNC( dcomplex, z, blasname, blisname )
#define INSERT_GENTFUNC_BLAS_SDC( blasname, blisname ) \
#define INSERT_GENTFUNC_BLAS_SC( blasname, blisname ) \
\
GENTFUNC( float, s, blasname, blisname ) \
GENTFUNC( double, d, blasname, blisname ) \
GENTFUNC( scomplex, c, blasname, blisname )
#define INSERT_GENTFUNC_BLAS_CZ( blasname, blisname ) \
\
GENTFUNC( scomplex, c, blasname, blisname ) \

View File

@@ -60,16 +60,16 @@ static err_t bli_sgemm_small
cntl_t* cntl
);
static err_t bli_dgemm_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);
/* static err_t bli_dgemm_small */
/* ( */
/* obj_t* alpha, */
/* obj_t* a, */
/* obj_t* b, */
/* obj_t* beta, */
/* obj_t* c, */
/* cntx_t* cntx, */
/* cntl_t* cntl */
/* ); */
static err_t bli_sgemm_small_atbn
(
@@ -1713,7 +1713,7 @@ static err_t bli_sgemm_small
};
static err_t bli_dgemm_small
/*static*/ err_t bli_dgemm_small
(
obj_t* alpha,
obj_t* a,
@@ -1725,43 +1725,21 @@ static err_t bli_dgemm_small
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
gint_t M = bli_obj_length( c ); // number of rows of Matrix C
gint_t N = bli_obj_width( c ); // number of columns of Matrix C
gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
gint_t L = M * N;
gint_t M = bli_obj_length( c ); // number of rows of Matrix C
gint_t N = bli_obj_width( c ); // number of columns of Matrix C
gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) .
gint_t L = M * N;
// when N is equal to 1 call GEMV instead of GEMM
if (N == 1)
{
bli_gemv
(
alpha,
a,
b,
beta,
c
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
return BLIS_SUCCESS;
}
if (N<3) //Implemenation assumes that N is atleast 3.
{
AOCL_DTL_TRACE_EXIT_ERR(
AOCL_DTL_LEVEL_INFO,
"N < 3, cannot be processed by small gemm"
);
return BLIS_NOT_YET_IMPLEMENTED;
}
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME))))
#else
if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES))
|| ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0)))
#endif
/* #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME */
/* if( (L && K) && ((K < D_BLIS_SMALL_MATRIX_K_THRES_ROME) || ((N < BLIS_SMALL_MATRIX_THRES_ROME) && (K < BLIS_SMALL_MATRIX_THRES_ROME)))) */
/* #else */
/* if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) */
/* || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) */
/* #endif */
if(L && K ) // Zero-sized problems (L == 0 or K == 0) skip this path and are handled by either sup or native kernels
{
guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled.
guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled.

View File

@@ -213,3 +213,14 @@ GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_2x4n )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_1x4n )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x2 )
GEMMSUP_KER_PROT( dcomplex, z, gemmsup_rv_zen_asm_3x1 )
/*
 * Prototype for the double-precision small-matrix GEMM routine.
 *
 * The definition is no longer `static` in this change, so this declaration
 * lets other translation units call it; the BLAS dgemm_ wrapper does so
 * directly.
 *
 * alpha, a, b, beta, c : operands passed as BLIS objects.
 * cntx, cntl           : the dgemm_ wrapper passes NULL for both.
 *
 * Returns an err_t. The dgemm_ wrapper treats BLIS_SUCCESS as "handled" and
 * otherwise continues to the sup path, then the native path.
 */
err_t bli_dgemm_small
(
obj_t* alpha,
obj_t* a,
obj_t* b,
obj_t* beta,
obj_t* c,
cntx_t* cntx,
cntl_t* cntl
);