mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Added ZTRSM AVX512 small code path
- Kernel dimensions are 4x4.
- Two kernels are implemented, Right Upper and
Right lower.
- In case of Left variants of TRSM, transpose is
induced so that Right variant kernels can be used.
- No packing is performed in these kernels.
- Changes are made in the threshold to pick ZTRSM small
code path.
- BLIS_INLINE is removed from signature of
"TRSMSMALL_KER_PROT".
- These kernels do not support "ENABLE_TRSM_PREINVERSION".
- Newly added kernels do not support conjugate
transpose.
- Added multithreading to ZTRSM small code path.
AMD-Internal: [CPUPL-4324]
Change-Id: I683b1d5239593e54f433e7f27497d72dfbd9141c
This commit is contained in:
committed by
Shubham Sharma
parent
1d983e6124
commit
b9e21e8701
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2019 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -71,7 +71,7 @@ err_t PASTEMAC0(opname) \
|
||||
|
||||
#define TRSMSMALL_KER_PROT( ch, opname ) \
|
||||
\
|
||||
BLIS_INLINE err_t PASTEMAC(ch,opname) \
|
||||
err_t PASTEMAC(ch,opname) \
|
||||
( \
|
||||
obj_t* AlphaObj, \
|
||||
obj_t* a, \
|
||||
|
||||
@@ -1286,11 +1286,19 @@ void bli_nthreads_optimum(
|
||||
{
|
||||
dim_t m = bli_obj_length(c);
|
||||
dim_t n = bli_obj_width(c);
|
||||
|
||||
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
if ( (m <= 300) && (n <= 300) )
|
||||
n_threads_ideal = 8;
|
||||
else if ( (m <= 400) && (n <= 400) )
|
||||
n_threads_ideal = 16;
|
||||
else if ( (m <= 900) && (n <= 900) )
|
||||
n_threads_ideal = 32;
|
||||
#else
|
||||
if((m>=64) && (m<=256) && (n>=64) && (n<=256))
|
||||
{
|
||||
n_threads_ideal = 8;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if( family == BLIS_GEMMT && bli_obj_is_double(c) )
|
||||
{
|
||||
|
||||
@@ -1535,12 +1535,27 @@ void ztrsm_blis_impl
|
||||
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
// This function is invoked on all architectures including 'generic'.
|
||||
// Non-AVX2+FMA3 platforms will use the kernels derived from the context.
|
||||
if (bli_cpuid_is_avx2fma3_supported() == TRUE)
|
||||
if ( bli_cpuid_is_avx2fma3_supported() == TRUE )
|
||||
{
|
||||
/* bli_ztrsm_small is performing better existing native
|
||||
* implementations for [m,n]<=1000 for single thread.
|
||||
* In case of multithread when [m,n]<=128 single thread implementation
|
||||
* is doing better than native multithread */
|
||||
typedef err_t (*ztrsm_small_ker_ft)
|
||||
(
|
||||
side_t side,
|
||||
obj_t* alpha,
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
cntx_t* cntx,
|
||||
cntl_t* cntl,
|
||||
bool is_parallel
|
||||
);
|
||||
err_t status = BLIS_NOT_YET_IMPLEMENTED;
|
||||
|
||||
// trsm small kernel function pointer definition
|
||||
ztrsm_small_ker_ft ker_ft = NULL;
|
||||
arch_t id = bli_arch_query_id();
|
||||
bool is_parallel = bli_thread_get_is_parallel();
|
||||
dim_t dim_a = n0;
|
||||
if (blis_side == BLIS_LEFT)
|
||||
@@ -1548,30 +1563,59 @@ void ztrsm_blis_impl
|
||||
|
||||
// size of output matrix(B)
|
||||
dim_t size_b = m0*n0;
|
||||
if((!is_parallel && m0<=500 && n0<=500) ||
|
||||
(is_parallel && (m0+n0)<128) ||
|
||||
(dim_a<35 && size_b<3500))
|
||||
#if defined(BLIS_ENABLE_OPENMP) && defined(BLIS_KERNELS_ZEN4)
|
||||
if (( is_parallel ) &&
|
||||
( (dim_a > 10) && (dim_a < 2500) && (size_b > 500) && (size_b < 5e5) ) &&
|
||||
( id == BLIS_ARCH_ZEN4 ))
|
||||
{
|
||||
err_t status;
|
||||
status = bli_trsm_small
|
||||
(
|
||||
blis_side,
|
||||
&alphao,
|
||||
&ao,
|
||||
&bo,
|
||||
NULL,
|
||||
NULL,
|
||||
is_parallel
|
||||
);
|
||||
if (status == BLIS_SUCCESS)
|
||||
ker_ft = bli_trsm_small_mt_AVX512;
|
||||
}
|
||||
#endif
|
||||
if( ( ker_ft == NULL ) &&
|
||||
( ( ( !is_parallel ) &&
|
||||
( (( m0 <= 500 ) && ( n0 <= 500 )) || ( (dim_a < 75) && (size_b < 3.2e5)))) ||
|
||||
( ( is_parallel ) &&
|
||||
( (m0 + n0 < 180) || (size_b < 5000) ) )
|
||||
)
|
||||
)
|
||||
{
|
||||
switch (id)
|
||||
{
|
||||
AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
case BLIS_ARCH_ZEN5:
|
||||
case BLIS_ARCH_ZEN4:
|
||||
#if defined(BLIS_KERNELS_ZEN4)
|
||||
// ZTRSM AVX512 code path do not support
|
||||
// conjugate
|
||||
if (!bli_obj_has_conj(&ao))
|
||||
{
|
||||
ker_ft = bli_trsm_small_AVX512;
|
||||
}
|
||||
else
|
||||
{
|
||||
ker_ft = bli_trsm_small;
|
||||
}
|
||||
break;
|
||||
#endif // BLIS_KERNELS_ZEN4
|
||||
case BLIS_ARCH_ZEN:
|
||||
case BLIS_ARCH_ZEN2:
|
||||
case BLIS_ARCH_ZEN3:
|
||||
default:
|
||||
ker_ft = bli_trsm_small;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(ker_ft)
|
||||
{
|
||||
status = ker_ft(blis_side, &alphao, &ao, &bo, NULL, NULL, is_parallel);
|
||||
}
|
||||
if (status == BLIS_SUCCESS)
|
||||
{
|
||||
AOCL_DTL_LOG_TRSM_STATS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *side, *m, *n);
|
||||
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
|
||||
/* Finalize BLIS. */
|
||||
bli_finalize_auto();
|
||||
return;
|
||||
}
|
||||
} // bli_cpuid_is_avx2fma3_supported
|
||||
#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM
|
||||
|
||||
|
||||
@@ -167,6 +167,29 @@ INSTANTIATE_TEST_SUITE_P (
|
||||
(::trsmNatUKRPrint<dcomplex,zgemmtrsm_ukr_ft>())
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P (
|
||||
bli_trsm_small_AVX512,
|
||||
ztrsmUkrSmall,
|
||||
::testing::Combine(
|
||||
::testing::Values(bli_trsm_small_AVX512), // ker_ptr
|
||||
::testing::Values('l', 'r'), // side
|
||||
::testing::Values('l', 'u'), // uplo
|
||||
::testing::Values('n', 'u'), // diaga
|
||||
::testing::Values('n', 't'), // transa
|
||||
::testing::Range(gtint_t(1), gtint_t(5), 1), // m
|
||||
::testing::Range(gtint_t(1), gtint_t(5), 1), // n
|
||||
::testing::Values(dcomplex{-1.4, 3.2},
|
||||
dcomplex{ 2.8, -0.5},
|
||||
dcomplex{-1.4, 0.0},
|
||||
dcomplex{ 0.0, -1.9}), // alpha
|
||||
::testing::Values(0, 10, 194), // lda_inc
|
||||
::testing::Values(0, 10, 194), // ldb_inc
|
||||
::testing::Values(false, true) // is_memory_test
|
||||
),
|
||||
(::trsmSmallUKRPrint<dcomplex, trsm_small_ker_ft>())
|
||||
);
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
@@ -5123,7 +5123,10 @@ err_t bli_trsm_small
|
||||
switch(dt)
|
||||
{
|
||||
case BLIS_DOUBLE:
|
||||
case BLIS_DCOMPLEX:
|
||||
{
|
||||
// threshold checks for these datatypes is
|
||||
// done at bla layer
|
||||
break;
|
||||
}
|
||||
case BLIS_FLOAT:
|
||||
@@ -5134,13 +5137,6 @@ err_t bli_trsm_small
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BLIS_DCOMPLEX:
|
||||
{
|
||||
if((!is_parallel) && (m > 500 || n > 500)) {
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
##Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved.##
|
||||
##Copyright (C) 2022-24, Advanced Micro Devices, Inc. All rights reserved.##
|
||||
|
||||
add_library(zen4_3
|
||||
OBJECT
|
||||
@@ -9,6 +9,7 @@ add_library(zen4_3
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_32x6.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_8x24.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small_AVX512.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_ztrsm_small_AVX512.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_12x4.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_zero_zmm.c
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_zen4_asm_4x12.c
|
||||
|
||||
@@ -152,7 +152,7 @@ typedef err_t (*trsmsmall_ker_ft)
|
||||
Pack a block of 8xk from input buffer into packed buffer
|
||||
directly or after transpose based on input params
|
||||
*/
|
||||
BLIS_INLINE void bli_dtrsm_small_pack_avx512
|
||||
void bli_dtrsm_small_pack_avx512
|
||||
(
|
||||
char side,
|
||||
dim_t size,
|
||||
@@ -406,7 +406,7 @@ BLIS_INLINE void bli_dtrsm_small_pack_avx512
|
||||
a. This helps in utilze cache line efficiently in TRSM operation
|
||||
b. store ones when input is unit diagonal
|
||||
*/
|
||||
BLIS_INLINE void dtrsm_small_pack_diag_element_avx512
|
||||
void dtrsm_small_pack_diag_element_avx512
|
||||
(
|
||||
bool is_unitdiag,
|
||||
double* a11,
|
||||
@@ -486,14 +486,14 @@ trsmsmall_ker_ft ker_fps_AVX512[4][8] =
|
||||
bli_dtrsm_small_XAltB_XAuB_AVX512,
|
||||
bli_dtrsm_small_XAltB_XAuB_AVX512,
|
||||
bli_dtrsm_small_XAutB_XAlB_AVX512},
|
||||
{NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL},
|
||||
{bli_ztrsm_small_AutXB_AlXB_AVX512,
|
||||
bli_ztrsm_small_AltXB_AuXB_AVX512,
|
||||
bli_ztrsm_small_AltXB_AuXB_AVX512,
|
||||
bli_ztrsm_small_AutXB_AlXB_AVX512,
|
||||
bli_ztrsm_small_XAutB_XAlB_AVX512,
|
||||
bli_ztrsm_small_XAltB_XAuB_AVX512,
|
||||
bli_ztrsm_small_XAltB_XAuB_AVX512,
|
||||
bli_ztrsm_small_XAutB_XAlB_AVX512},
|
||||
};
|
||||
/*
|
||||
* The bli_trsm_small implements a version of TRSM where A is packed and reused
|
||||
@@ -526,12 +526,12 @@ err_t bli_trsm_small_AVX512
|
||||
switch (dt)
|
||||
{
|
||||
case BLIS_DOUBLE:
|
||||
case BLIS_DCOMPLEX:
|
||||
{
|
||||
break;
|
||||
}
|
||||
case BLIS_FLOAT:
|
||||
case BLIS_SCOMPLEX:
|
||||
case BLIS_DCOMPLEX:
|
||||
default:
|
||||
{
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
@@ -602,6 +602,11 @@ err_t bli_trsm_small_mt_AVX512
|
||||
d_mr = 8, d_nr = 8;
|
||||
break;
|
||||
}
|
||||
case BLIS_DCOMPLEX:
|
||||
{
|
||||
d_mr = 4, d_nr = 4;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
return BLIS_NOT_YET_IMPLEMENTED;
|
||||
@@ -616,7 +621,7 @@ err_t bli_trsm_small_mt_AVX512
|
||||
// If dynamic-threading is enabled, calculate optimum number
|
||||
// of threads.
|
||||
// rntm will be updated with optimum number of threads.
|
||||
if (bli_obj_is_double(b))
|
||||
if (bli_obj_is_double(b) || bli_obj_is_dcomplex(b) )
|
||||
{
|
||||
bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm);
|
||||
}
|
||||
@@ -1984,7 +1989,7 @@ err_t bli_trsm_small_mt_AVX512
|
||||
// endregion - pre/post DTRSM macros for right variants
|
||||
|
||||
// RUNN - RLTN
|
||||
BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512
|
||||
err_t bli_dtrsm_small_XAltB_XAuB_AVX512
|
||||
(
|
||||
obj_t* AlphaObj,
|
||||
obj_t* a,
|
||||
@@ -4314,7 +4319,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512
|
||||
|
||||
|
||||
// RLNN - RUTN
|
||||
BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512
|
||||
err_t bli_dtrsm_small_XAutB_XAlB_AVX512
|
||||
(
|
||||
obj_t* AlphaObj,
|
||||
obj_t* a,
|
||||
@@ -7232,7 +7237,7 @@ zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0]
|
||||
_mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0));
|
||||
|
||||
// LLNN - LUTN
|
||||
BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512
|
||||
err_t bli_dtrsm_small_AutXB_AlXB_AVX512
|
||||
(
|
||||
obj_t* AlphaObj,
|
||||
obj_t* a,
|
||||
@@ -9203,7 +9208,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512
|
||||
|
||||
|
||||
// LUNN LUTN
|
||||
BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512
|
||||
err_t bli_dtrsm_small_AltXB_AuXB_AVX512
|
||||
(
|
||||
obj_t* AlphaObj,
|
||||
obj_t* a,
|
||||
|
||||
1052
kernels/zen4/3/bli_ztrsm_small_AVX512.c
Normal file
1052
kernels/zen4/3/bli_ztrsm_small_AVX512.c
Normal file
File diff suppressed because it is too large
Load Diff
@@ -175,6 +175,10 @@ TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( z, trsm_small_AutXB_AlXB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( z, trsm_small_XAltB_XAuB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( z, trsm_small_XAutB_XAlB_AVX512 )
|
||||
TRSMSMALL_KER_PROT( z, trsm_small_AltXB_AuXB_AVX512 )
|
||||
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
TRSMSMALL_PROT(trsm_small_mt_AVX512)
|
||||
|
||||
Reference in New Issue
Block a user