diff --git a/gtestsuite/testsuite/ukr/trsm/dtrsm_ukr.cpp b/gtestsuite/testsuite/ukr/trsm/dtrsm_ukr.cpp
new file mode 100644
index 000000000..c78af7946
--- /dev/null
+++ b/gtestsuite/testsuite/ukr/trsm/dtrsm_ukr.cpp
@@ -0,0 +1,172 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <gtest/gtest.h>
+#include <string>
+#include <tuple>
+#include "common/testing_helpers.h"
+#include "level3/ref_gemm.h"
+#include "test_trsm_ukr.h"
+#include "level3/trsm/test_trsm.h"
+
+// Parameter tuple shared by every dgemmtrsm micro-kernel test case.
+using DTrsmUkrParams = std::tuple<dgemmtrsm_ukr_ft,   // ker_ptr
+                                  char,               // storage
+                                  char,               // uploa
+                                  char,               // diaga
+                                  gtint_t,            // m
+                                  gtint_t,            // n
+                                  gtint_t,            // k
+                                  double,             // alpha
+                                  gtint_t>;           // ldc_inc
+
+class DTrsmUkrTest :
+        public ::testing::TestWithParam<DTrsmUkrParams> {};
+
+// Checks one kernel/storage/uplo/diag/size/alpha/ldc combination against the
+// reference GEMM + TRSM path implemented in test_trsm_ukr().
+TEST_P(DTrsmUkrTest, native)
+{
+    using T = double;
+    dgemmtrsm_ukr_ft ukr_fp  = std::get<0>(GetParam());
+    char             storage = std::get<1>(GetParam());
+    char             uploa   = std::get<2>(GetParam());
+    char             diaga   = std::get<3>(GetParam());
+    gtint_t          m       = std::get<4>(GetParam());
+    gtint_t          n       = std::get<5>(GetParam());
+    gtint_t          k       = std::get<6>(GetParam());
+    T                alpha   = std::get<7>(GetParam());
+    gtint_t          ldc     = std::get<8>(GetParam());
+
+    // the tolerance grows linearly with m
+    double thresh = 2 * m * testinghelpers::getEpsilon<T>();
+    test_trsm_ukr( ukr_fp, storage, uploa, diaga, m, n, k, alpha, ldc, thresh );
+}
+
+// Builds the test-name suffix from the storage, diag, uplo, k, alpha and ldc
+// parameters.
+class DTrsmUkrTestPrint {
+public:
+    std::string operator()(
+        testing::TestParamInfo<DTrsmUkrParams> str) const {
+        char storage = std::get<1>(str.param);
+        char uploa   = std::get<2>(str.param);
+        char diaga   = std::get<3>(str.param);
+        gtint_t k    = std::get<6>(str.param);
+        double alpha = std::get<7>(str.param);
+        gtint_t ldc  = std::get<8>(str.param);
+
+        // alpha is truncated to an integer; values that are not positive get
+        // an 'm' prefix in place of the sign.
+        const std::string alpha_str = ( alpha > 0 )
+            ? std::to_string(int(alpha))
+            : std::string("m") + std::to_string(int(alpha*-1));
+
+        return std::string("dgemmtrsm_ukr") + "_s" + storage + "_d" + diaga + "_u" + uploa
+               + "_k" + std::to_string(k) + "_a" + alpha_str + "_c" + std::to_string(ldc);
+    }
+};
+
+#ifdef BLIS_KERNELS_ZEN4
+INSTANTIATE_TEST_SUITE_P (
+    bli_dgemmtrsm_l_zen4_asm_8x24,
+    DTrsmUkrTest,
+    ::testing::Combine(
+        ::testing::Values(bli_dgemmtrsm_l_zen4_asm_8x24),  // ker_ptr
+        ::testing::Values('c', 'r', 'g'),                  // stor
+        ::testing::Values('l'),                            // uplo
+        ::testing::Values('u', 'n'),                       // diaga
+        ::testing::Values(8),                              // m
+        ::testing::Values(24),                             // n
+        ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000),   // k
+        ::testing::Values(-1, -5.2, 1, 8.9),               // alpha
+        ::testing::Values(0, 9, 53)                        // ldc
+    ),
+    ::DTrsmUkrTestPrint()
+);
+
+INSTANTIATE_TEST_SUITE_P (
+    bli_dgemmtrsm_u_zen4_asm_8x24,
+    DTrsmUkrTest,
+    ::testing::Combine(
+        ::testing::Values(bli_dgemmtrsm_u_zen4_asm_8x24),  // ker_ptr
+        ::testing::Values('c', 'r', 'g'),                  // stor
+        ::testing::Values('u'),                            // uplo
+        ::testing::Values('u', 'n'),                       // diaga
+        ::testing::Values(8),                              // m
+        ::testing::Values(24),                             // n
+        ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000),   // k
+        ::testing::Values(-1, -5.2, 1, 8.9),               // alpha
+        ::testing::Values(0, 9, 53)                        // ldc
+    ),
+    ::DTrsmUkrTestPrint()
+);
+#endif
+
+
+#ifdef BLIS_KERNELS_HASWELL
+INSTANTIATE_TEST_SUITE_P (
+    bli_dgemmtrsm_l_haswell_asm_6x8,
+    DTrsmUkrTest,
+    ::testing::Combine(
+        ::testing::Values(bli_dgemmtrsm_l_haswell_asm_6x8), // ker_ptr
+        ::testing::Values('c', 'r', 'g'),                   // stor
+        ::testing::Values('l'),                             // uplo
+        ::testing::Values('u', 'n'),                        // diaga
+        ::testing::Values(6),                               // m
+        ::testing::Values(8),                               // n
+        ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000),    // k
+        ::testing::Values(-1, -5.2, 1, 8.9),                // alpha
+        ::testing::Values(0, 9, 53)                         // ldc
+    ),
+    ::DTrsmUkrTestPrint()
+);
+
+INSTANTIATE_TEST_SUITE_P (
+    bli_dgemmtrsm_u_haswell_asm_6x8,
+    DTrsmUkrTest,
+    ::testing::Combine(
+        ::testing::Values(bli_dgemmtrsm_u_haswell_asm_6x8), // ker_ptr
+        ::testing::Values('c', 'r', 'g'),                   // stor
+        ::testing::Values('u'),                             // uplo
+        ::testing::Values('u', 'n'),                        // diaga
+        ::testing::Values(6),                               // m
+        ::testing::Values(8),                               // n
+        ::testing::Values(0, 1, 2, 8, 9, 10, 500, 1000),    // k
+        ::testing::Values(-1, -5.2, 1, 8.9),                // alpha
+        ::testing::Values(0, 9, 53)                         // ldc
+    ),
+    ::DTrsmUkrTestPrint()
+);
+#endif
diff --git a/gtestsuite/testsuite/ukr/trsm/test_trsm_ukr.h b/gtestsuite/testsuite/ukr/trsm/test_trsm_ukr.h
new file mode 100644
index 000000000..d57db8491
--- /dev/null
+++ b/gtestsuite/testsuite/ukr/trsm/test_trsm_ukr.h
@@ -0,0 +1,253 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#pragma once
+
+#include <cstdlib>
+#include <memory>
+#include <new>
+#include <utility>
+#include <vector>
+#include "level3/trsm/trsm.h"
+#include "blis.h"
+#include "level3/ref_gemm.h"
+#include "level3/ref_trsm.h"
+#include "inc/check_error.h"
+#include "common/testing_helpers.h"
+#include "level3/trsm/test_trsm.h"
+
+namespace trsm_ukr_detail {
+
+// Deleter that hands aligned_alloc() storage back to free().
+struct free_deleter
+{
+    void operator()( void* ptr ) const { std::free( ptr ); }
+};
+
+} // namespace trsm_ukr_detail
+
+/**
+ * @brief Runs a gemmtrsm micro-kernel once and checks it against the
+ *        reference ref_gemm + ref_trsm sequence.
+ *
+ * The packed panels [A10 A11] (column major, lda = m) and [B01; B11]
+ * (row major, ldb = n) are generated, B11 is copied into C using the
+ * requested storage scheme, the kernel is invoked, and the same update is
+ * recomputed with the reference routines before both results are compared.
+ *
+ * @param ukr_fp   Micro-kernel under test.
+ * @param storage  Layout of C: 'r'/'R' row major, 'c'/'C' column major, any
+ *                 other value selects general strides (rs = ldc, cs = ldc*ldc).
+ * @param uploa    Upper/lower selector forwarded to make_triangular and ref_trsm.
+ * @param diaga    'u'/'U' overwrites the diagonal of A11 with ones.
+ * @param m        Order of A11; number of rows of C.
+ * @param n        Number of columns of C.
+ * @param k        Inner dimension of the GEMM part.
+ * @param alpha    Scalar passed to the kernel and used as beta in ref_gemm.
+ * @param ldc_inc  Padding added to the minimal leading dimension of C.
+ * @param thresh   Error threshold forwarded to computediff.
+ * @throws std::bad_alloc if the aligned B buffer cannot be allocated.
+ */
+template<typename T, typename FT>
+static void test_trsm_ukr( FT ukr_fp, char storage, char uploa, char diaga,
+                           gtint_t m, gtint_t n, gtint_t k, T alpha,
+                           gtint_t ldc_inc, double thresh)
+{
+    const gtint_t lda = m, ldb = n;
+    gtint_t ldc = ldc_inc;
+
+    // [A10 A11] is m x (k+m), column major. std::vector owns the storage so
+    // it is released on every exit path.
+    std::vector<T> a_buf( static_cast<std::size_t>( (k+m) * lda ) );
+    T* a10 = a_buf.data();
+
+    // [B01; B11] is (k+m) x n, row major. aligned_alloc() requires the size to
+    // be a multiple of the alignment, so the request is padded (never zero).
+    const std::size_t b_align = BLIS_HEAP_STRIDE_ALIGN_SIZE;
+    const std::size_t b_bytes = static_cast<std::size_t>( (k+m) * ldb ) * sizeof(T);
+    const std::size_t b_alloc = ( b_bytes / b_align + 1 ) * b_align;
+    std::unique_ptr<T, trsm_ukr_detail::free_deleter> b_buf(
+        static_cast<T*>( aligned_alloc( b_align, b_alloc ) ) );
+    if ( !b_buf ) throw std::bad_alloc();
+    T* b01 = b_buf.get();
+
+    //----------------------------------------------------------
+    // Initialize the packed panels.
+    //----------------------------------------------------------
+    init_mat( a10, uploa, 'c', 'n', 3, 10, m, (k+m), lda);
+    init_mat( b01, uploa, 'r', 'n', 3, 10, n, (k+m), ldb);
+
+    // A11 follows the k columns of A10, B11 follows the k rows of B01
+    T* a11 = a10 + (k*lda);
+    T* b11 = b01 + (k*ldb);
+
+    // make A11 triangular for trsm
+    testinghelpers::make_triangular( 'c', uploa, m, a11, lda );
+
+    gtint_t rs_c, cs_c, rs_c_ref, cs_c_ref;
+    std::size_t len_c, len_c_ref;   // element counts of C and C_ref
+
+    // choose strides and sizes of C according to the storage scheme
+    if (storage == 'r' || storage == 'R')
+    {
+        ldc += n;
+        rs_c = ldc, cs_c = 1;
+        rs_c_ref = rs_c, cs_c_ref = cs_c;
+        len_c = len_c_ref = static_cast<std::size_t>( ldc * m );
+    }
+    else if (storage == 'c' || storage == 'C')
+    {
+        ldc += m;
+        cs_c = ldc, rs_c = 1;
+        rs_c_ref = rs_c, cs_c_ref = cs_c;
+        len_c = len_c_ref = static_cast<std::size_t>( ldc * n );
+    }
+    else
+    {
+        // general stride for the kernel, column major for the reference
+        ldc += m;
+        rs_c_ref = 1, cs_c_ref = ldc;
+        rs_c = ldc, cs_c = ldc*ldc;
+        len_c     = static_cast<std::size_t>( ldc * n * ldc );
+        len_c_ref = static_cast<std::size_t>( ldc * n );
+    }
+
+    // value-initialised, i.e. zero-filled, like the memset() they replace
+    std::vector<T> c( len_c );
+    std::vector<T> c_ref( len_c_ref );
+
+    // copy contents of B11 to C and C_ref
+    for (gtint_t i = 0; i < m; ++i)
+    {
+        for (gtint_t j = 0; j < n; ++j)
+        {
+            c[j*cs_c + i*rs_c]             = b11[i*ldb + j];
+            c_ref[j*cs_c_ref + i*rs_c_ref] = b11[i*ldb + j];
+        }
+    }
+
+    // make A11 diagonal dominant
+    for (gtint_t i = 0; i < m; i++)
+    {
+        a11[i+i*lda] = T{float(m)}*a11[i+i*lda];
+    }
+
+    if (diaga == 'u' || diaga == 'U')
+    {
+        for (gtint_t i = 0; i < m; i++)
+        {
+            a11[i+i*lda] = 1;
+        }
+    }
+
+    //----------------------------------------------------------
+    // Call BLIS function.
+    //----------------------------------------------------------
+    ukr_fp
+    (
+        k,
+        &alpha,
+        a10, a11,
+        b01, b11,
+        c.data(),
+        rs_c, cs_c,
+        nullptr, nullptr
+    );
+
+#ifdef BLIS_ENABLE_TRSM_PREINVERSION
+    // compensate for the trsm pre-inversion
+    for (gtint_t i = 0; i < m; i++)
+    {
+        a11[i+i*lda] = 1/a11[i+i*lda];
+    }
+#endif
+
+    //----------------------------------------------------------
+    // Call reference implementation to get ref results.
+    //----------------------------------------------------------
+    // layout in which c and c_ref are compared
+    char cmp_storage = storage;
+
+    if (storage == 'c' || storage == 'C')
+    {
+        testinghelpers::ref_gemm<T>( storage, 'n', 't', m, n, k, -1,
+                                     a10, lda, b01, ldb, alpha, c_ref.data(), ldc);
+        testinghelpers::ref_trsm<T>( storage, 'l', uploa, 'n', diaga, m, n, 1, a11,
+                                     lda, c_ref.data(), ldc );
+    }
+    else if (storage == 'r' || storage == 'R')// row major
+    {
+        testinghelpers::ref_gemm<T>( storage, 't', 'n', m, n, k, -1,
+                                     a10, lda, b01, ldb, alpha, c_ref.data(), ldc);
+
+        // convert col major A11 to row major for TRSM (in-place transpose)
+        for (gtint_t i = 0; i < m; ++i)
+        {
+            for (gtint_t j = i + 1; j < m; ++j)
+            {
+                std::swap( a11[i+j*lda], a11[j+i*lda] );
+            }
+        }
+
+        testinghelpers::ref_trsm<T>( storage, 'l', uploa, 'n', diaga, m, n, 1, a11,
+                                     lda, c_ref.data(), ldc );
+    }
+    else
+    {
+        testinghelpers::ref_gemm<T>( 'c', 'n', 't', m, n, k, -1,
+                                     a10, lda, b01, ldb, alpha, c_ref.data(), ldc);
+        testinghelpers::ref_trsm<T>( 'c', 'l', uploa, 'n', diaga, m, n, 1, a11,
+                                     lda, c_ref.data(), ldc );
+
+        // gather the general-stride kernel output into a column-major copy
+        // that has the same layout as c_ref
+        std::vector<T> c_cm( len_c_ref );
+        for (gtint_t i = 0; i < m; ++i)
+        {
+            for (gtint_t j = 0; j < n; ++j)
+            {
+                c_cm[i*rs_c_ref + j*cs_c_ref] = c[i*rs_c + j*cs_c];
+            }
+        }
+        c.swap( c_cm );
+
+        // both buffers are column major now, so compare them as such
+        cmp_storage = 'c';
+    }
+
+    //----------------------------------------------------------
+    // Compute component-wise error.
+    //----------------------------------------------------------
+    computediff( cmp_storage, m, n, c.data(), c_ref.data(), ldc, thresh );
+}