Added new ZTRSM small code path for ZEN5

- Added new ZTRSM kernels for right and left variants.
- Kernel dimensions are 12x4.
- 12x4 ZGEMM SUP kernels are used internally
  for solving GEMM subproblem.
- These kernels do not support conjugate transpose.
- Only column major inputs are supported.
- Tuned thresholds to pick the most efficient code path for ZEN5.

AMD-Internal: [CPUPL-6356]
Change-Id: I33ba3d337b0fcd972ca9cfe4668cb23d2b279b6e
This commit is contained in:
Shubham Sharma
2025-02-06 18:01:10 +05:30
parent 2e687d8847
commit f8c83fedb6
8 changed files with 3545 additions and 1648 deletions

View File

@@ -1666,51 +1666,103 @@ void ztrsm_blis_impl
// trsm small kernel function pointer definition
ztrsm_small_ker_ft ker_ft = NULL;
arch_t id = bli_arch_query_id();
bool is_parallel = bli_thread_get_is_parallel();
dim_t dim_a = n0;
arch_t id = bli_arch_query_id();
bool is_parallel = bli_thread_get_is_parallel();
dim_t dim_a = n0;
(void) dim_a; //avoid unused warning for zen2/3
if (blis_side == BLIS_LEFT)
dim_a = m0;
// size of output matrix(B)
dim_t size_b = m0*n0;
#if defined(BLIS_ENABLE_OPENMP) && defined(BLIS_KERNELS_ZEN4)
if (( is_parallel ) &&
( (dim_a > 10) && (dim_a < 2500) && (size_b > 500) && (size_b < 5e5) ) &&
( id == BLIS_ARCH_ZEN4 ))
#if defined(BLIS_ENABLE_OPENMP)
switch (id)
{
if (!bli_obj_has_conj(&ao))
case BLIS_ARCH_ZEN5:
#if defined(BLIS_KERNELS_ZEN5)
if (( is_parallel ) &&
( (dim_a > 10) && (dim_a < 2500) && (size_b > 500) && (size_b < 5e5) ))
{
ker_ft = bli_trsm_small_mt_AVX512;
if (!bli_obj_has_conj(&ao)) // if transa == 'C', go to native code path
{
ker_ft = bli_trsm_small_mt_ZEN5; // 12x4 non fused kernel for ZEN5
}
}
else
break;
#endif //BLIS_KERNELS_ZEN5
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
if (( is_parallel ) &&
( (dim_a > 10) && (dim_a < 2500) && (size_b > 500) && (size_b < 5e5) ))
{
ker_ft = bli_trsm_small_mt;
if (!bli_obj_has_conj(&ao))
{
ker_ft = bli_trsm_small_mt_AVX512; // 4x4 fused kernel for ZEN4
}
else
{
ker_ft = bli_trsm_small_mt;
}
}
break;
#endif //BLIS_KERNELS_ZEN4
default:
break;
}
#endif
if( ( ker_ft == NULL ) &&
( ( ( !is_parallel ) &&
( (( m0 <= 500 ) && ( n0 <= 500 )) || ( (dim_a < 75) && (size_b < 3.2e5)))) ||
( ( is_parallel ) &&
( (m0 + n0 < 180) || (size_b < 5000) ) )
( ( !is_parallel ) ||
( ( is_parallel ) &&
( (m0 + n0 < 180) || (size_b < 5000) ) )
)
)
{
switch (id)
{
case BLIS_ARCH_ZEN5:
#if defined(BLIS_KERNELS_ZEN5)
if (bli_obj_has_conj(&ao))
break; // conjugate not supported in AVX512 small code path
// Decision logic tuned using Powell optimizer from scikit-learn
if ( blis_side == BLIS_LEFT )
{
if ( m0 <= 88 )
{
ker_ft = bli_trsm_small_AVX512;
}
else if ( (log10(n0) + (0.15*log10(m0)) ) < 2.924 )
{
ker_ft = bli_trsm_small_ZEN5;
}
}
else //if ( blis_side == BLIS_RIGHT )
{
if ( (log10(m0) + (2.8*log10(n0)) ) < 6 )
{
ker_ft = bli_trsm_small_AVX512;
}
else if ( (log10(m0) + (1.058*log10(n0)) ) < 5.373 )
{
ker_ft = bli_trsm_small_ZEN5;
}
}
break;
#endif //BLIS_KERNELS_ZEN5
case BLIS_ARCH_ZEN4:
#if defined(BLIS_KERNELS_ZEN4)
// ZTRSM AVX512 code path do not support
// conjugate
if (!bli_obj_has_conj(&ao))
if ((( m0 <= 500 ) && ( n0 <= 500 )) || ( (dim_a < 75) && (size_b < 3.2e5)))
{
ker_ft = bli_trsm_small_AVX512;
}
else
{
ker_ft = bli_trsm_small;
// ZTRSM AVX512 code path do not support
// conjugate
if (!bli_obj_has_conj(&ao))
{
ker_ft = bli_trsm_small_AVX512;
}
else
{
ker_ft = bli_trsm_small;
}
}
break;
#endif // BLIS_KERNELS_ZEN4

View File

@@ -65,6 +65,7 @@
#define K_bli_dgemv_t_zen_int_mx3_avx512 1
#define K_bli_dgemv_t_zen_int_mx2_avx512 1
#define K_bli_dgemv_t_zen_int_mx1_avx512 1
#define K_bli_ztrsm_small_ZEN5 1
#define AOCL_50

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023 - 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2023 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -169,6 +169,52 @@ INSTANTIATE_TEST_SUITE_P(
::trsmGenericPrint<dcomplex>()
);
/**
* @brief Test ZTRSM small avx512 path fringe cases.
* Kernel sizes for the avx512 small path are 4x4 and 12x4, so m and n are swept
* over the range 1 to 24 (1 to 12 for trsm and 12 to 24 for the gemm subproblem).
* NOTE(review): Range(1, 25, 2) uses a step of 2, so only the odd sizes
* 1, 3, ..., 23 are generated; even fringe sizes are not exercised by this suite.
*/
INSTANTIATE_TEST_SUITE_P(
Small_AVX512_fringe,
ztrsmGeneric,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Range(gtint_t(1), gtint_t(25), 2), // m (1, 3, ..., 23)
::testing::Range(gtint_t(1), gtint_t(25), 2), // n (1, 3, ..., 23)
::testing::Values(dcomplex{2.0,-3.4}), // alpha
::testing::Values(gtint_t(56)), // increment to the leading dim of a
::testing::Values(gtint_t(33)) // increment to the leading dim of b
),
::trsmGenericPrint<dcomplex>()
);
/**
* @brief Test ZTRSM small avx512 path. The same test also covers the small_zen5
* kernels when run with BLIS_ARCH_TYPE=zen5.
* m takes the values {17, 500} and n the values {48, 500}; each pair is combined
* with both sides, both uplo settings, transa 'n'/'t' and both diag settings.
* Conjugate transpose is not part of this suite.
*/
INSTANTIATE_TEST_SUITE_P(
Small_AVX512,
ztrsmGeneric,
::testing::Combine(
::testing::Values('c'), // storage format
::testing::Values('l','r'), // side l:left, r:right
::testing::Values('u','l'), // uplo u:upper, l:lower
::testing::Values('n', 't'), // transa
::testing::Values('n','u'), // diaga , n=nonunit u=unit
::testing::Values(17, 500), // m
::testing::Values(48, 500), // n
::testing::Values(dcomplex{2.0,-3.4}), // alpha
::testing::Values(gtint_t(54)), // increment to the leading dim of a
::testing::Values(gtint_t(37)) // increment to the leading dim of b
),
::trsmGenericPrint<dcomplex>()
);
/**
* @brief Test ZTRSM with different values of alpha
* code paths covered:

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2024 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -132,6 +132,74 @@ TEST_P( ztrsmGenericSmall, UKR )
test_trsm_small_ukr<T, trsm_small_ker_ft>( ukr_fp, side, uploa, diaga, transa, m, n, alpha, lda, ldb, thresh, is_memory_test, BLIS_DCOMPLEX);
}
#if defined(BLIS_KERNELS_ZEN5) && defined(GTEST_AVX512)
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
#ifdef K_bli_ztrsm_small_ZEN5
// Right-side ('r') tests that call bli_trsm_small_ZEN5 directly.
// Sweeps every m in 1..12 and every n in 1..4 for both uplo, diag and
// transa ('n'/'t') settings, two alpha values, with and without extra
// leading-dimension padding, and with the memory test both off and on.
// NOTE(review): the 12 x 4 sweep presumably matches the 12x4 kernel tile so
// that every fringe size of a single tile is covered - confirm.
INSTANTIATE_TEST_SUITE_P(
bli_trsm_small_ZEN5_r,
ztrsmGenericSmall,
::testing::Combine(
::testing::Values(bli_trsm_small_ZEN5), // ker_ptr
::testing::Values('r'), // side
::testing::Values('l', 'u'), // uplo
::testing::Values('n', 'u'), // diaga
::testing::Values('n', 't'), // transa
::testing::Range(gtint_t(1), gtint_t(13), 1), // m ( 1 to 12)
::testing::Range(gtint_t(1), gtint_t(5), 1), // n ( 1 to 4 )
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 194), // lda_inc
::testing::Values(0, 194), // ldb_inc
::testing::Values(false, true) // is_memory_test
),
(::trsmSmallUKRPrint<dcomplex, trsm_small_ker_ft>())
);
// Left-side ('l') tests that call bli_trsm_small_ZEN5 directly.
// Mirror of the right-side suite with the dimension sweeps swapped:
// every m in 1..4 and every n in 1..12, for both uplo, diag and
// transa ('n'/'t') settings, two alpha values, with and without extra
// leading-dimension padding, and with the memory test both off and on.
INSTANTIATE_TEST_SUITE_P(
bli_trsm_small_ZEN5_l,
ztrsmGenericSmall,
::testing::Combine(
::testing::Values(bli_trsm_small_ZEN5), // ker_ptr
::testing::Values('l'), // side
::testing::Values('l', 'u'), // uplo
::testing::Values('n', 'u'), // diaga
::testing::Values('n', 't'), // transa
::testing::Range(gtint_t(1), gtint_t(5), 1), // m ( 1 to 4 )
::testing::Range(gtint_t(1), gtint_t(13), 1), // n ( 1 to 12)
::testing::Values(dcomplex{-1.4, 3.2},
dcomplex{ 0.0, -1.9}), // alpha
::testing::Values(0, 194), // lda_inc
::testing::Values(0, 194), // ldb_inc
::testing::Values(false, true) // is_memory_test
),
(::trsmSmallUKRPrint<dcomplex, trsm_small_ker_ft>())
);
// Larger-size tests that call bli_trsm_small_ZEN5 directly, for both sides.
// Range(12, 48, 5) generates m and n in {12, 17, 22, 27, 32, 37, 42, 47}.
// NOTE(review): judging by the suite name, these sizes are meant to span
// more than one kernel tile so that the internal GEMM subproblem is
// exercised - confirm against the kernel implementation.
INSTANTIATE_TEST_SUITE_P(
bli_trsm_small_ZEN5_gemm,
ztrsmGenericSmall,
::testing::Combine(
::testing::Values(bli_trsm_small_ZEN5), // ker_ptr
::testing::Values('l', 'r'), // side
::testing::Values('l', 'u'), // uplo
::testing::Values('n', 'u'), // diaga
::testing::Values('n', 't'), // transa
::testing::Range(gtint_t(12), gtint_t(48), 5), // m (12, 17, ..., 47)
::testing::Range(gtint_t(12), gtint_t(48), 5), // n (12, 17, ..., 47)
::testing::Values(dcomplex{-1.4, 3.2}), // alpha
::testing::Values(0, 10), // lda_inc
::testing::Values(0, 10), // ldb_inc
::testing::Values(false, true) // is_memory_test
),
(::trsmSmallUKRPrint<dcomplex, trsm_small_ker_ft>())
);
#endif // K_bli_ztrsm_small_ZEN5
#endif // BLIS_ENABLE_SMALL_MATRIX_TRSM
#endif // defined(BLIS_KERNELS_ZEN5) && defined(GTEST_AVX512)
#if defined(BLIS_KERNELS_ZEN4) && defined(GTEST_AVX512)
#ifdef K_bli_zgemmtrsm_l_zen4_asm_4x12
INSTANTIATE_TEST_SUITE_P(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -75,6 +75,11 @@ TRSMSMALL_KER_PROT( d, trsm_small_XAutB_XAlB_ZEN5 )
TRSMSMALL_KER_PROT( d, trsm_small_AltXB_AuXB_ZEN5 )
TRSMSMALL_KER_PROT( d, trsm_small_AutXB_AlXB_ZEN5 )
// z-typed (ZTRSM) ZEN5 small-kernel prototypes, using the same four variant
// names as the d-typed prototypes declared above.
// NOTE(review): the XA*B names presumably denote the right-side variants and
// the A*XB names the left-side variants - confirm against the kernel sources.
TRSMSMALL_KER_PROT( z, trsm_small_XAltB_XAuB_ZEN5 )
TRSMSMALL_KER_PROT( z, trsm_small_XAutB_XAlB_ZEN5 )
TRSMSMALL_KER_PROT( z, trsm_small_AltXB_AuXB_ZEN5 )
TRSMSMALL_KER_PROT( z, trsm_small_AutXB_AlXB_ZEN5 )
#ifdef BLIS_ENABLE_OPENMP
err_t bli_trsm_small_mt_ZEN5
(