Added ZTRSM AVX512 small code path

- Kernel dimensions are 4x4.
  - Two kernels are implemented: Right Upper and
    Right Lower.
  - In case of Left variants of TRSM, transpose is
    induced so that Right variant kernels can be used.
  - No packing is performed in these kernels.
  - Changes are made in the threshold to pick ZTRSM small
    code path.
  - BLIS_INLINE is removed from signature of
    "TRSMSMALL_KER_PROT".
  - These kernels do not support "ENABLE_TRSM_PREINVERSION".
  - Newly added kernels do not support conjugate
    transpose.
  - Added multithreading to ZTRSM small code path.

AMD-Internal: [CPUPL-4324]
Change-Id: I683b1d5239593e54f433e7f27497d72dfbd9141c
This commit is contained in:
Shubham Sharma
2024-05-03 13:52:01 +05:30
committed by Shubham Sharma
parent 1d983e6124
commit b9e21e8701
9 changed files with 1181 additions and 48 deletions

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -71,7 +71,7 @@ err_t PASTEMAC0(opname) \
#define TRSMSMALL_KER_PROT( ch, opname ) \
\
BLIS_INLINE err_t PASTEMAC(ch,opname) \
err_t PASTEMAC(ch,opname) \
( \
obj_t* AlphaObj, \
obj_t* a, \

View File

@@ -1286,11 +1286,19 @@ void bli_nthreads_optimum(
{
dim_t m = bli_obj_length(c);
dim_t n = bli_obj_width(c);
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
if ( (m <= 300) && (n <= 300) )
n_threads_ideal = 8;
else if ( (m <= 400) && (n <= 400) )
n_threads_ideal = 16;
else if ( (m <= 900) && (n <= 900) )
n_threads_ideal = 32;
#else
if((m>=64) && (m<=256) && (n>=64) && (n<=256))
{
n_threads_ideal = 8;
}
#endif
}
else if( family == BLIS_GEMMT && bli_obj_is_double(c) )
{

View File

@@ -1535,12 +1535,27 @@ void ztrsm_blis_impl
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
// This function is invoked on all architectures including 'generic'.
// Non-AVX2+FMA3 platforms will use the kernels derived from the context.
if (bli_cpuid_is_avx2fma3_supported() == TRUE)
if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
{
/* bli_ztrsm_small performs better than the existing native
* implementations for [m,n]<=1000 in the single-threaded case.
* In the multithreaded case, when [m,n]<=128, the single-threaded
* implementation performs better than the native multithreaded one. */
typedef err_t (*ztrsm_small_ker_ft)
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl,
bool is_parallel
);
err_t status = BLIS_NOT_YET_IMPLEMENTED;
// trsm small kernel function pointer definition
ztrsm_small_ker_ft ker_ft = NULL;
arch_t id = bli_arch_query_id();
bool is_parallel = bli_thread_get_is_parallel();
dim_t dim_a = n0;
if (blis_side == BLIS_LEFT)
@@ -1548,30 +1563,59 @@ void ztrsm_blis_impl
// size of output matrix(B)
dim_t size_b = m0*n0;
if((!is_parallel && m0<=500 && n0<=500) ||
(is_parallel && (m0+n0)<128) ||
(dim_a<35 && size_b<3500))
#if defined(BLIS_ENABLE_OPENMP) && defined(BLIS_KERNELS_ZEN4)
if (( is_parallel ) &&
( (dim_a > 10) && (dim_a < 2500) && (size_b > 500) && (size_b < 5e5) ) &&
( id == BLIS_ARCH_ZEN4 ))
{
err_t status;
status = bli_trsm_small
(
blis_side,
&alphao,
&ao,
&bo,
NULL,
NULL,
is_parallel
);
if (status == BLIS_SUCCESS)
ker_ft = bli_trsm_small_mt_AVX512;
}
#endif
if( ( ker_ft == NULL ) &&
( ( ( !is_parallel ) &&
( (( m0 <= 500 ) && ( n0 <= 500 )) || ( (dim_a < 75) && (size_b < 3.2e5)))) ||
( ( is_parallel ) &&
( (m0 + n0 < 180) || (size_b < 5000) ) )
)
)
{
switch (id)
{
AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
case BLIS_ARCH_ZEN5:
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
// The ZTRSM AVX512 code path does not support
// conjugation.
if (!bli_obj_has_conj(&ao))
{
ker_ft = bli_trsm_small_AVX512;
}
else
{
ker_ft = bli_trsm_small;
}
break;
#endif // BLIS_KERNELS_ZEN4
case BLIS_ARCH_ZEN:
case BLIS_ARCH_ZEN2:
case BLIS_ARCH_ZEN3:
default:
ker_ft = bli_trsm_small;
break;
}
}
if(ker_ft)
{
status = ker_ft(blis_side, &alphao, &ao, &bo, NULL, NULL, is_parallel);
}
if (status == BLIS_SUCCESS)
{
AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
/* Finalize BLIS. */
bli_finalize_auto();
return;
}
} // bli_cpuid_is_avx2fma3_supported
#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM

View File

@@ -167,6 +167,29 @@ INSTANTIATE_TEST_SUITE_P (
(::trsmNatUKRPrint<dcomplex,zgemmtrsm_ukr_ft>())
);
// Instantiates the ztrsmUkrSmall parameterized tests for the
// bli_trsm_small_AVX512 kernel with dcomplex data.
// ::testing::Combine builds the cartesian product of every value list
// below, so each combination of side/uplo/diag/trans/m/n/alpha/ld
// increments/memory-test flag becomes its own test case.
// Test names are generated by trsmSmallUKRPrint.
INSTANTIATE_TEST_SUITE_P (
bli_trsm_small_AVX512,
ztrsmUkrSmall,
::testing::Combine(
::testing::Values(bli_trsm_small_AVX512), // ker_ptr
::testing::Values('l', 'r'), // side
::testing::Values('l', 'u'), // uplo
::testing::Values('n', 'u'), // diaga
::testing::Values('n', 't'), // transa
::testing::Range(gtint_t(1), gtint_t(5), 1), // m: 1..4 (Range excludes the end value)
::testing::Range(gtint_t(1), gtint_t(5), 1), // n: 1..4 (Range excludes the end value)
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 2.8, -0.5},
dcomplex{-1.4, 0.0},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 10, 194), // lda_inc
::testing::Values(0, 10, 194), // ldb_inc
::testing::Values(false, true) // is_memory_test
),
(::trsmSmallUKRPrint<dcomplex, trsm_small_ker_ft>())
);
#endif

View File

@@ -5123,7 +5123,10 @@ err_t bli_trsm_small
switch(dt)
{
case BLIS_DOUBLE:
case BLIS_DCOMPLEX:
{
// Threshold checks for these datatypes are
// done at the BLA layer.
break;
}
case BLIS_FLOAT:
@@ -5134,13 +5137,6 @@ err_t bli_trsm_small
}
break;
}
case BLIS_DCOMPLEX:
{
if((!is_parallel) && (m > 500 || n > 500)) {
return BLIS_NOT_YET_IMPLEMENTED;
}
break;
}
default:
{
return BLIS_NOT_YET_IMPLEMENTED;

View File

@@ -1,4 +1,4 @@
##Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved.##
##Copyright (C) 2022-24, Advanced Micro Devices, Inc. All rights reserved.##
add_library(zen4_3
OBJECT
@@ -9,6 +9,7 @@ add_library(zen4_3
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_32x6.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_8x24.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small_AVX512.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_ztrsm_small_AVX512.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_12x4.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_zero_zmm.c
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_4x12.c

View File

@@ -152,7 +152,7 @@ typedef err_t (*trsmsmall_ker_ft)
Pack a block of 8xk from input buffer into packed buffer
directly or after transpose based on input params
*/
BLIS_INLINE void bli_dtrsm_small_pack_avx512
void bli_dtrsm_small_pack_avx512
(
char side,
dim_t size,
@@ -406,7 +406,7 @@ BLIS_INLINE void bli_dtrsm_small_pack_avx512
a. This helps utilize cache lines efficiently in the TRSM operation
b. store ones when input is unit diagonal
*/
BLIS_INLINE void dtrsm_small_pack_diag_element_avx512
void dtrsm_small_pack_diag_element_avx512
(
bool is_unitdiag,
double* a11,
@@ -486,14 +486,14 @@ trsmsmall_ker_ft ker_fps_AVX512[4][8] =
bli_dtrsm_small_XAltB_XAuB_AVX512,
bli_dtrsm_small_XAltB_XAuB_AVX512,
bli_dtrsm_small_XAutB_XAlB_AVX512},
{NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL,
NULL},
{bli_ztrsm_small_AutXB_AlXB_AVX512,
bli_ztrsm_small_AltXB_AuXB_AVX512,
bli_ztrsm_small_AltXB_AuXB_AVX512,
bli_ztrsm_small_AutXB_AlXB_AVX512,
bli_ztrsm_small_XAutB_XAlB_AVX512,
bli_ztrsm_small_XAltB_XAuB_AVX512,
bli_ztrsm_small_XAltB_XAuB_AVX512,
bli_ztrsm_small_XAutB_XAlB_AVX512},
};
/*
* The bli_trsm_small implements a version of TRSM where A is packed and reused
@@ -526,12 +526,12 @@ err_t bli_trsm_small_AVX512
switch (dt)
{
case BLIS_DOUBLE:
case BLIS_DCOMPLEX:
{
break;
}
case BLIS_FLOAT:
case BLIS_SCOMPLEX:
case BLIS_DCOMPLEX:
default:
{
return BLIS_NOT_YET_IMPLEMENTED;
@@ -602,6 +602,11 @@ err_t bli_trsm_small_mt_AVX512
d_mr = 8, d_nr = 8;
break;
}
case BLIS_DCOMPLEX:
{
d_mr = 4, d_nr = 4;
break;
}
default:
{
return BLIS_NOT_YET_IMPLEMENTED;
@@ -616,7 +621,7 @@ err_t bli_trsm_small_mt_AVX512
// If dynamic-threading is enabled, calculate optimum number
// of threads.
// rntm will be updated with optimum number of threads.
if (bli_obj_is_double(b))
if (bli_obj_is_double(b) || bli_obj_is_dcomplex(b) )
{
bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm);
}
@@ -1984,7 +1989,7 @@ err_t bli_trsm_small_mt_AVX512
// endregion - pre/post DTRSM macros for right variants
// RUNN - RLTN
BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512
err_t bli_dtrsm_small_XAltB_XAuB_AVX512
(
obj_t* AlphaObj,
obj_t* a,
@@ -4314,7 +4319,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512
// RLNN - RUTN
BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512
err_t bli_dtrsm_small_XAutB_XAlB_AVX512
(
obj_t* AlphaObj,
obj_t* a,
@@ -7232,7 +7237,7 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0]
_mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0));
// LLNN - LUTN
BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512
err_t bli_dtrsm_small_AutXB_AlXB_AVX512
(
obj_t* AlphaObj,
obj_t* a,
@@ -9203,7 +9208,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512
// LUNN - LLTN
BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512
err_t bli_dtrsm_small_AltXB_AuXB_AVX512
(
obj_t* AlphaObj,
obj_t* a,

File diff suppressed because it is too large Load Diff

View File

@@ -175,6 +175,10 @@ TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 )
TRSMSMALL_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 )
TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 )
TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )
TRSMSMALL_KER_PROT( z, trsm_small_AutXB_AlXB_AVX512 )
TRSMSMALL_KER_PROT( z, trsm_small_XAltB_XAuB_AVX512 )
TRSMSMALL_KER_PROT( z, trsm_small_XAutB_XAlB_AVX512 )
TRSMSMALL_KER_PROT( z, trsm_small_AltXB_AuXB_AVX512 )
#ifdef BLIS_ENABLE_OPENMP
TRSMSMALL_PROT(trsm_small_mt_AVX512)