Added DTRSM Small Path AVX512 based LLNN/LUTN Variant Kernels

- 8x8 kernels are used for DTRSM SMALL
- Implemented fringe cases with below block sizes
   8x8, 8x4, 8x3, 8x2, 8x1
   4x8, 4x4, 4x3, 4x2, 4x1
   3x8, 3x4, 3x3, 3x2, 3x1
   2x8, 2x4, 2x3, 2x2, 2x1
   1x8, 1x4, 1x3, 1x2, 1x1

AMD-Internal: [CPUPL-2745]

Change-Id: I58d28912bddbaadb404052c0f3449ebbe3c97b68
This commit is contained in:
Aayush Kumar
2023-03-23 06:14:39 +00:00
parent fa024b82ad
commit 8c537b0cd5
4 changed files with 2811 additions and 157 deletions

View File

@@ -1059,10 +1059,9 @@ void dtrsm_blis_impl
{
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
// check if variant is RUN[N/U] or RLT[N/U]
// this is a temporary fix, will be removed when all variants are added
if( (blis_side == BLIS_RIGHT) &&
((n0 > 300) && (m0 > 50)))
if( ((blis_side == BLIS_RIGHT) && ((n0 > 300) && (m0 > 50))) ||
((blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_NO_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_TRANSPOSE) ) ) && ((n0 != 30 && n0 !=60 ) && (m0 > 50))) )
{
ker_ft = bli_trsm_small_AVX512;
}
@@ -1089,13 +1088,13 @@ void dtrsm_blis_impl
{
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
if( blis_side == BLIS_RIGHT )
if ( (blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_NO_TRANSPOSE) ) ))
{
ker_ft = bli_trsm_small_mt_AVX512;
ker_ft = bli_trsm_small_mt;
}
else
{
ker_ft = bli_trsm_small_mt;
ker_ft = bli_trsm_small_mt_AVX512;
}
break;
#endif// BLIS_KERNELS_ZEN4

View File

@@ -0,0 +1,129 @@
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
#define DIAG_ELE_INV_OPS(a, b) (a / b)
#define DIAG_ELE_EVAL_OPS(a, b) (a * b)
#endif
#ifdef BLIS_DISABLE_TRSM_PREINVERSION
#define DIAG_ELE_INV_OPS(a, b) (a * b)
#define DIAG_ELE_EVAL_OPS(a, b) (a / b)
#endif
// reference code for LUTN
BLIS_INLINE err_t dtrsm_AutXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool unitDiagonal
)
{
dim_t i, j, k;
for (k = 0; k < M; k++)
{
double lkk_inv = 1.0;
if (!unitDiagonal)
lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
for (j = 0; j < N; j++)
{
B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
for (i = k + 1; i < M; i++)
{
B[i + j * ldb] -= A[i * lda + k] * B[k + j * ldb];
}
}
} // k -loop
return BLIS_SUCCESS;
}
// reference code for LLNN
BLIS_INLINE err_t dtrsm_AlXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
)
{
dim_t i, j, k;
for (k = 0; k < M; k++)
{
double lkk_inv = 1.0;
if (!is_unitdiag)
lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
for (j = 0; j < N; j++)
{
B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
for (i = k + 1; i < M; i++)
{
B[i + j * ldb] -= A[i + k * lda] * B[k + j * ldb];
}
}
} // k -loop
return BLIS_SUCCESS;
}
// reference code for LUNN
BLIS_INLINE err_t dtrsm_AuXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
)
{
dim_t i, j, k;
for (k = M - 1; k >= 0; k--)
{
double lkk_inv = 1.0;
if (!is_unitdiag)
lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
for (j = N - 1; j >= 0; j--)
{
B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
for (i = k - 1; i >= 0; i--)
{
B[i + j * ldb] -= A[i + k * lda] * B[k + j * ldb];
}
}
} // k -loop
return BLIS_SUCCESS;
} // end of function
// reference code for LLTN
BLIS_INLINE err_t dtrsm_AltXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
)
{
dim_t i, j, k;
for (k = M - 1; k >= 0; k--)
{
double lkk_inv = 1.0;
if (!is_unitdiag)
lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
for (j = N - 1; j >= 0; j--)
{
B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
for (i = k - 1; i >= 0; i--)
{
B[i + j * ldb] -= A[i * lda + k] * B[k + j * ldb];
}
}
} // k -loop
return BLIS_SUCCESS;
} // end of function

View File

@@ -35,6 +35,7 @@
#include "blis.h"
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
#include "immintrin.h"
#include "bli_trsm_small_ref.h"
#define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
@@ -107,18 +108,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB
cntl_t* cntl
);
//AX = B; A is lower triangular; transpose;
//double precision; non-unit diagonal
BLIS_INLINE err_t dtrsm_AltXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
);
/*
* ZTRSM kernel declaration
*/
@@ -248,41 +237,6 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB
#define DIAG_ELE_EVAL_OPS(a,b) (a / b)
#endif
/*
* Reference implementations
* ToDo: We can combine all these reference implementation
into a macro
*/
//A'X = B; A is upper triangular; transpose;
//non-unitDiagonal double precision
BLIS_INLINE err_t dtrsm_AutXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool unitDiagonal
)
{
dim_t i, j, k;
for (k = 0; k < M; k++)
{
double lkk_inv = 1.0;
if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
for (j = 0; j < N; j++)
{
B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv);
for (i = k+1; i < M; i++)
{
B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb];
}
}
}// k -loop
return BLIS_SUCCESS;
}// end of function
/*
* Reference implementations
* ToDo: We can combine all these reference implementation
@@ -318,37 +272,6 @@ BLIS_INLINE err_t strsm_AutXB_ref
return BLIS_SUCCESS;
}// end of function
/* TRSM scalar code for the case AX = alpha * B
* A is upper-triangular, non-unit-diagonal
* Dimensions: A: mxm X: mxn B:mxn
*/
BLIS_INLINE err_t dtrsm_AuXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
)
{
dim_t i, j, k;
for (k = M-1; k >= 0; k--)
{
double lkk_inv = 1.0;
if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
for (j = N -1; j >= 0; j--)
{
B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv);
for (i = k-1; i >=0; i--)
{
B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
}
}
}// k -loop
return BLIS_SUCCESS;
}// end of function
/* TRSM scalar code for the case AX = alpha * B
* A is upper-triangular, non-unit-diagonal
@@ -382,37 +305,6 @@ BLIS_INLINE err_t strsm_AuXB_ref
return BLIS_SUCCESS;
}// end of function
/* TRSM scalar code for the case AX = alpha * B
* A is lower-triangular, non-unit-diagonal, no transpose
* Dimensions: A: mxm X: mxn B:mxn
*/
BLIS_INLINE err_t dtrsm_AlXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
)
{
dim_t i, j, k;
for (k = 0; k < M; k++)
{
double lkk_inv = 1.0;
if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
for (j = 0; j < N; j++)
{
B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv);
for (i = k+1; i < M; i++)
{
B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
}
}
}// k -loop
return BLIS_SUCCESS;
}// end of function
/* TRSM scalar code for the case AX = alpha * B
* A is lower-triangular, non-unit-diagonal, no transpose
@@ -446,38 +338,6 @@ BLIS_INLINE err_t strsm_AlXB_ref
return BLIS_SUCCESS;
}// end of function
/* TRSM scalar code for the case AX = alpha * B
* A is lower-triangular, non-unit-diagonal, transpose
* Dimensions: A: mxm X: mxn B:mxn
*/
BLIS_INLINE err_t dtrsm_AltXB_ref
(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb,
bool is_unitdiag
)
{
dim_t i, j, k;
for (k = M-1; k >= 0; k--)
{
double lkk_inv = 1.0;
if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
for (j = N -1; j >= 0; j--)
{
B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv);
for (i = k-1; i >=0; i--)
{
B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb];
}
}
}// k -loop
return BLIS_SUCCESS;
}// end of function
/* TRSM scalar code for the case AX = alpha * B
* A is lower-triangular, non-unit-diagonal, transpose
* Dimensions: A: mxm X: mxn B:mxn

File diff suppressed because it is too large Load Diff