diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index 4dc8c5aa9..b64da054c 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -97,4 +97,4 @@ #include "bli_trmm.h" #include "bli_trmm3.h" #include "bli_trsm.h" - +#include "bli_gemmt.h" diff --git a/frame/3/bli_l3_cntl.c b/frame/3/bli_l3_cntl.c index f6bfbedbb..54a2804bf 100644 --- a/frame/3/bli_l3_cntl.c +++ b/frame/3/bli_l3_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,7 +55,8 @@ void bli_l3_cntl_create_if { if ( family == BLIS_GEMM || family == BLIS_HERK || - family == BLIS_TRMM ) + family == BLIS_TRMM || + family == BLIS_GEMMT) { *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b ); } diff --git a/frame/3/bli_l3_oapi.c b/frame/3/bli_l3_oapi.c index aec1e43d4..16a8ef8d1 100644 --- a/frame/3/bli_l3_oapi.c +++ b/frame/3/bli_l3_oapi.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -124,6 +124,68 @@ void PASTEMAC(opname,EX_SUF) \ GENFRONT( gemm ) +#undef GENFRONT +#define GENFRONT( opname ) \ +\ +void PASTEMAC(opname,EX_SUF) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c \ + BLIS_OAPI_EX_PARAMS \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2) \ + bli_init_once(); \ +\ + BLIS_OAPI_EX_DECLS \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_2, alpha, a, b, beta, c); \ +\ + /* If C has a zero dimension, return early. */ \ + if ( bli_obj_has_zero_dim( c ) ) {\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return; \ + }\ +\ + /* if alpha or A or B has a zero dimension, \ + scale C by beta and return early. */ \ + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || \ + bli_obj_has_zero_dim( a ) || \ + bli_obj_has_zero_dim( b ) ) \ + {\ + bli_scalm( beta, c ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2) \ + return;\ + }\ +\ + /* Only proceed with an induced method if each of the operands have a + complex storage datatype. NOTE: Allowing precisions to vary while + using 1m, which is what we do here, is unique to gemm; other level-3 + operations use 1m only if all storage datatypes are equal (and they + ignore the computation precision). If any operands are real, skip the + induced method chooser function and proceed directly with native + execution. */ \ + if ( bli_obj_is_complex( c ) && \ + bli_obj_is_complex( a ) && \ + bli_obj_is_complex( b ) ) \ + { \ + /* Invoke the operation's "ind" function--its induced method front-end. + For complex problems, it calls the highest priority induced method + that is available (ie: implemented and enabled), and if none are + enabled, it calls native execution. (For real problems, it calls + the operation's native execution interface.) */ \ + PASTEMAC(opname,ind)( alpha, a, b, beta, c, cntx, rntm ); \ + } \ + else \ + { \ + PASTEMAC(opname,nat)( alpha, a, b, beta, c, cntx, rntm ); \ + } \ + \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ +} +GENFRONT( gemmt ) #undef GENFRONT #define GENFRONT( opname ) \ diff --git a/frame/3/bli_l3_oapi.h b/frame/3/bli_l3_oapi.h index 4f9f20608..fcbc9dec4 100644 --- a/frame/3/bli_l3_oapi.h +++ b/frame/3/bli_l3_oapi.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -51,6 +52,7 @@ BLIS_EXPORT_BLIS void PASTEMAC(opname,EX_SUF) \ ); GENPROT( gemm ) +GENPROT( gemmt ) GENPROT( her2k ) GENPROT( syr2k ) diff --git a/frame/3/bli_l3_oft.h b/frame/3/bli_l3_oft.h index c182ed56c..e7c8dcca3 100644 --- a/frame/3/bli_l3_oft.h +++ b/frame/3/bli_l3_oft.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -57,6 +58,7 @@ typedef void (*PASTECH(opname,_oft)) \ ); GENTDEF( gemm ) +GENTDEF( gemmt ) GENTDEF( her2k ) GENTDEF( syr2k ) diff --git a/frame/3/bli_l3_tapi.c b/frame/3/bli_l3_tapi.c index 7b7f758ab..b66fef882 100644 --- a/frame/3/bli_l3_tapi.c +++ b/frame/3/bli_l3_tapi.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -100,7 +100,7 @@ void PASTEMAC2(ch,opname,EX_SUF) \ } INSERT_GENTFUNC_BASIC0( gemm ) - +INSERT_GENTFUNC_BASIC0( gemmt ) #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, struca ) \ diff --git a/frame/3/bli_l3_tapi.h b/frame/3/bli_l3_tapi.h index a809c2a68..c9c87585d 100644 --- a/frame/3/bli_l3_tapi.h +++ b/frame/3/bli_l3_tapi.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,7 +57,7 @@ BLIS_EXPORT_BLIS void PASTEMAC2(ch,opname,EX_SUF) \ ); INSERT_GENTPROT_BASIC0( gemm ) - +INSERT_GENTPROT_BASIC0( gemmt ) #undef GENTPROT #define GENTPROT( ctype, ch, opname ) \ diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index d7cd0a92c..1e50fd7c7 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -65,6 +65,7 @@ cntl_t* bli_gemmbp_cntl_create if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2; else if ( family == BLIS_HERK ) macro_kernel_fp = bli_herk_x_ker_var2; else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2; + else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_ker_var2; else /* should never execute */ macro_kernel_fp = NULL; packa_fp = bli_packm_blk_var1; diff --git a/frame/3/gemmt/bli_gemmt.h b/frame/3/gemmt/bli_gemmt.h new file mode 100644 index 000000000..5b599fe8c --- /dev/null +++ b/frame/3/gemmt/bli_gemmt.h @@ -0,0 +1,37 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "bli_gemmt_front.h" + +#include "bli_gemmt_var.h" diff --git a/frame/3/gemmt/bli_gemmt_front.c b/frame/3/gemmt/bli_gemmt_front.c new file mode 100644 index 000000000..bd2095f53 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_front.c @@ -0,0 +1,365 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemmt_front + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_3, alpha, a, b, beta, c); + bli_init_once(); + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + +#ifdef BLIS_ENABLE_GEMM_MD + cntx_t cntx_local; + + // If any of the storage datatypes differ, or if the computation precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: If we ever want to support the caller setting the computation + // domain explicitly, we will need to check the computation dt against the + // storage dt of C (instead of the computation precision against the + // storage precision of C). + if ( bli_obj_dt( &c_local ) != bli_obj_dt( &a_local ) || + bli_obj_dt( &c_local ) != bli_obj_dt( &b_local ) || + bli_obj_comp_prec( &c_local ) != bli_obj_prec( &c_local ) ) + { + // Handle mixed datatype cases in bli_gemm_md(), which may modify + // the objects or the context. (If the context is modified, cntx + // is adjusted to point to cntx_local.) + bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); + } + //else // homogeneous datatypes +#endif + + // Load the pack schemas from the context and embed them into the objects + // for A and B. (Native contexts are initialized with the correct pack + // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would + // have made a copy and modified the schemas, so reading them from the + // context should be a safe bet at this point.) This is a sort of hack for + // communicating the desired pack schemas to bli_gemm_cntl_create() (via + // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us + // to subsequently access the schemas from the control tree, which + // hopefully reduces some confusion, particularly in bli_packm_init(). + const pack_t schema_a = bli_cntx_schema_a_block( cntx ); + const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Next, we handle the possibility of needing to typecast alpha to the + // computation datatype and/or beta to the storage datatype of C. + + // Attach alpha to B, and in the process typecast alpha to the target + // datatype of the matrix (which in this case is equal to the computation + // datatype). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); + + // Attach beta to C, and in the process typecast beta to the target + // datatype of the matrix (which in this case is equal to the storage + // datatype of C). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, beta, &c_local ); + + // Change the alpha and beta pointers to BLIS_ONE since the values have + // now been typecast and attached to the matrices above. + alpha = &BLIS_ONE; + beta = &BLIS_ONE; + +#ifdef BLIS_ENABLE_GEMM_MD + // Don't perform the following optimization for ccr or crc cases, as + // those cases are sensitive to the ukernel storage preference (ie: + // transposing the operation would break them). + if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) +#endif + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + + // We must also swap the pack schemas, which were set by bli_gemm_md() + // or the inlined code above. + bli_obj_swap_pack_schemas( &a_local, &b_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMMT, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + obj_t* cp = &c_local; + obj_t* betap = beta; + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If any of the following conditions are met, create a temporary matrix + // conformal to C into which we will accumulate the matrix product: + // - the storage precision of C differs from the computation precision; + // - the domains are mixed as crr; + // - the storage format of C does not match the preferred orientation + // of the ccr or crc cases. + // Then, after the computation is complete, this matrix will be copied + // or accumulated back to C. + const bool_t is_ccr_mismatch = + ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_obj_is_col_stored( &c_local ) ); + const bool_t is_crc_mismatch = + ( bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) && + !bli_obj_is_row_stored( &c_local ) ); + + obj_t ct; + bool_t use_ct = FALSE; + + // FGVZ: Consider adding another guard here that only creates and uses a + // temporary matrix for accumulation if k < c * kc, where c is some small + // constant like 2. And don't forget to use the same conditional for the + // castm() and free() at the end. + if ( + bli_obj_prec( &c_local ) != bli_obj_comp_prec( &c_local ) || + bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) || + is_ccr_mismatch || + is_crc_mismatch + ) + { + use_ct = TRUE; + } + + // If we need a temporary matrix conformal to C for whatever reason, + // we create it and prepare to use it now. + if ( use_ct ) + { + const dim_t m = bli_obj_length( &c_local ); + const dim_t n = bli_obj_width( &c_local ); + inc_t rs = bli_obj_row_stride( &c_local ); + inc_t cs = bli_obj_col_stride( &c_local ); + + num_t dt_ct = bli_obj_domain( &c_local ) | + bli_obj_comp_prec( &c_local ); + + // When performing the crr case, accumulate to a contiguously-stored + // real matrix so we do not have to repeatedly update C with general + // stride. + if ( bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) ) + dt_ct = BLIS_REAL | bli_obj_comp_prec( &c_local ); + + // When performing the mismatched ccr or crc cases, now is the time + // to specify the appropriate storage so the gemm_md_c2r_ref() virtual + // microkernel can output directly to C (instead of using a temporary + // microtile). + if ( is_ccr_mismatch ) { rs = 1; cs = m; } + else if ( is_crc_mismatch ) { rs = n; cs = 1; } + + bli_obj_create( dt_ct, m, n, rs, cs, &ct ); + + const num_t dt_exec = bli_obj_exec_dt( &c_local ); + const num_t dt_comp = bli_obj_comp_dt( &c_local ); + + bli_obj_set_target_dt( dt_ct, &ct ); + bli_obj_set_exec_dt( dt_exec, &ct ); + bli_obj_set_comp_dt( dt_comp, &ct ); + + // A naive approach would cast C to the comptuation datatype, + // compute with beta, and then cast the result back to the + // user-provided output matrix. However, we employ a different + // approach that halves the number of memops on C (or its + // typecast temporary) by writing the A*B product directly to + // temporary storage, and then using xpbym to scale the + // output matrix by beta and accumulate/cast the A*B product. + //bli_castm( &c_local, &ct ); + betap = &BLIS_ZERO; + + cp = &ct; + } +#endif +#endif + + // Invoke the internal back-end via the thread handler. + bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_GEMMT, // operation family id + alpha, + &a_local, + &b_local, + betap, + cp, + cntx, + rntm, + cntl + ); + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If we created a temporary matrix conformal to C for whatever reason, + // we copy/accumulate the result back to C and then release the object. + if ( use_ct ) + { + obj_t beta_local; + + bli_obj_scalar_detach( &c_local, &beta_local ); + + //bli_castnzm( &ct, &c_local ); + bli_xpbym( &ct, &beta_local, &c_local ); + + bli_obj_free( &ct ); + } +#endif +#endif + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +// ----------------------------------------------------------------------------- + +#if 0 + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + const bool_t a_is_real = bli_obj_is_real( a ); + const bool_t a_is_comp = bli_obj_is_complex( a ); + const bool_t b_is_real = bli_obj_is_real( b ); + const bool_t b_is_comp = bli_obj_is_complex( b ); + const bool_t c_is_real = bli_obj_is_real( c ); + const bool_t c_is_comp = bli_obj_is_complex( c ); + + const bool_t a_is_single = bli_obj_is_single_prec( a ); + const bool_t a_is_double = bli_obj_is_double_prec( a ); + const bool_t b_is_single = bli_obj_is_single_prec( b ); + const bool_t b_is_double = bli_obj_is_double_prec( b ); + const bool_t c_is_single = bli_obj_is_single_prec( c ); + const bool_t c_is_double = bli_obj_is_double_prec( c ); + + const bool_t comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; + const bool_t comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; + + const bool_t mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || + bli_obj_domain( c ) != bli_obj_domain( b ); + + ( void )a_is_real; ( void )a_is_comp; + ( void )b_is_real; ( void )b_is_comp; + ( void )c_is_real; ( void )c_is_comp; + ( void )a_is_single; ( void )a_is_double; + ( void )b_is_single; ( void )b_is_double; + ( void )c_is_single; ( void )c_is_double; + ( void )comp_single; ( void )comp_double; + + if ( + //( c_is_comp && a_is_comp && b_is_real ) || + //( c_is_comp && a_is_real && b_is_comp ) || + //( c_is_real && a_is_comp && b_is_comp ) || + //( c_is_comp && a_is_real && b_is_real ) || + //( c_is_real && a_is_comp && b_is_real ) || + //( c_is_real && a_is_real && b_is_comp ) || + //FALSE + TRUE + ) + { + if ( + ( c_is_single && a_is_single && b_is_single && mixeddomain ) || + ( c_is_single && a_is_single && b_is_single && comp_single ) || + ( c_is_single && a_is_single && b_is_single && comp_double ) || + ( c_is_single && a_is_single && b_is_double ) || + ( c_is_single && a_is_double && b_is_single ) || + ( c_is_double && a_is_single && b_is_single ) || + ( c_is_single && a_is_double && b_is_double ) || + ( c_is_double && a_is_single && b_is_double ) || + ( c_is_double && a_is_double && b_is_single ) || + ( c_is_double && a_is_double && b_is_double && comp_single ) || + ( c_is_double && a_is_double && b_is_double && comp_double ) || + ( c_is_double && a_is_double && b_is_double && mixeddomain ) || + FALSE + ) + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + } + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#else +#if 0 + // If any of the storage datatypes differ, or if the execution precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: We could check the exec dt against the storage dt of C, but for + // now we don't support the caller setting the execution domain + // explicitly. + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#endif +#endif diff --git a/frame/3/gemmt/bli_gemmt_front.h b/frame/3/gemmt/bli_gemmt_front.h new file mode 100644 index 000000000..b97501643 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_front.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +void bli_gemmt_front + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ); + + diff --git a/frame/3/gemmt/bli_gemmt_ker_var2.c b/frame/3/gemmt/bli_gemmt_ker_var2.c new file mode 100644 index 000000000..e1cb64993 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_ker_var2.c @@ -0,0 +1,821 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + pack_t schema_a, + pack_t schema_b, + dim_t m_off, + dim_t n_off, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY_T(ftypes); + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname) \ +\ +void PASTEMAC(ch, varname) \ + ( \ + dim_t m_off, \ + dim_t n_off, \ + dim_t m_cur, \ + dim_t n_cur, \ + ctype* ct, inc_t rs_ct, inc_t cs_ct, \ + ctype* beta_cast, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + dim_t start, end; \ + dim_t m, n, diag; \ +\ + double beta_val = *beta_cast; \ +\ + start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \ + end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \ +\ + if( beta_val != 0.0 ) \ + { \ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = 0; n <= diag-n_off; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ +\ + for(; m < m_cur; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + } \ + else \ + { \ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = 0; n <= diag-n_off; n++) \ + c[m*rs_c + n] = ct[m*rs_ct + n]; \ +\ + for(; m < m_cur; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = ct[m*rs_ct + n]; \ + } \ +\ + return; \ +} +INSERT_GENTFUNC_BASIC0_SD( update_lower_triang ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch, varname) \ + ( \ + dim_t m_off, \ + dim_t n_off, \ + dim_t m_cur, \ + dim_t n_cur, \ + ctype* ct, inc_t rs_ct, inc_t cs_ct, \ + ctype* beta_cast, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + dim_t start, end; \ + dim_t m, n, diag; \ +\ + ctype beta_val = *beta_cast; \ +\ + start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \ + end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \ +\ + if( beta_val != 0.0 ) \ + { \ + for(m = 0; m < start-m_off; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ +\ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = diag-n_off; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + } \ + else \ + { \ + for(m = 0; m < start-m_off; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = ct[m*rs_ct + n]; \ +\ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = diag-n_off; n < n_cur; n++) \ + c[m*rs_c + n] = ct[m*rs_ct + n]; \ + } \ +\ + return; \ +} +INSERT_GENTFUNC_BASIC0_SD( update_upper_triang ) + +void bli_gemmt_ker_var2 + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_6); + +#ifdef BLIS_ENABLE_GEMM_MD + // By now, A and B have been packed and cast to the execution precision. + // In most cases, such as when storage precision of C differs from the + // execution precision, we utilize the mixed datatype code path. However, + // a few cases still fall within this kernel, such as mixed domain with + // equal precision (ccr, crc, rcc), hence those expressions being disabled + // in the conditional below. + if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) || + //( bli_obj_domain( c ) != bli_obj_domain( b ) ) || + ( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) ) + { + bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread ); + return; + } +#endif + + num_t dt_exec = bli_obj_exec_dt( c ); + + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m_off = bli_obj_row_off( c ); + dim_t n_off = bli_obj_col_off( c ); + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + void* buf_a = bli_obj_buffer_at_off( a ); + inc_t cs_a = bli_obj_col_stride( a ); + inc_t is_a = bli_obj_imag_stride( a ); + dim_t pd_a = bli_obj_panel_dim( a ); + inc_t ps_a = bli_obj_panel_stride( a ); + + void* buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b = bli_obj_row_stride( b ); + inc_t is_b = bli_obj_imag_stride( b ); + dim_t pd_b = bli_obj_panel_dim( b ); + inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + inc_t rs_c = bli_obj_row_stride( c ); + inc_t cs_c = bli_obj_col_stride( c ); + + obj_t scalar_a; + obj_t scalar_b; + + void* buf_alpha; + void* buf_beta; + + FUNCPTR_T f; + + bool_t uploc; + if ( bli_obj_is_lower( c ) ) + { + uploc = 0; + } + else + { + uploc = 1; + } + + // Detach and multiply the scalars attached to A and B. + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + buf_beta = bli_obj_internal_scalar_buffer( c ); + + // If 1m is being employed on a column- or row-stored matrix with a + // real-valued beta, we can use the real domain macro-kernel, which + // eliminates a little overhead associated with the 1m virtual + // micro-kernel. +#if 1 + if ( bli_cntx_method( cntx ) == BLIS_1M ) + { + bli_gemm_ind_recast_1m_params + ( + &dt_exec, + schema_a, + c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + &rs_c, &cs_c + ); + } +#endif + +#ifdef BLIS_ENABLE_GEMM_MD + // Tweak parameters in select mixed domain cases (rcc, crc, ccr). + bli_gemm_md_ker_var2_recast + ( + &dt_exec, + bli_obj_dt( a ), + bli_obj_dt( b ), + bli_obj_dt( c ), + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + c, + &rs_c, &cs_c + ); +#endif + + // Index into the type combination array to extract the correct + // function pointer. + f = ftypes[dt_exec][uploc]; + + // Invoke the function. + f( schema_a, + schema_b, + m_off, + n_off, + m, + n, + k, + buf_alpha, + buf_a, cs_a, is_a, + pd_a, ps_a, + buf_b, rs_b, is_b, + pd_b, ps_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); +} + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, uplo, varname ) \ +\ +void PASTEMACT(ch,opname,uplo,varname) \ + ( \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m_off, \ + dim_t n_off, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_6); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + if(bli_gemmt_is_strictly_above_diag(m_off, n_off, m, n)) return; \ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ + ctype* restrict b1; \ + ctype* restrict c1; \ +\ + dim_t m_iter, m_left; \ + dim_t n_iter, n_left; \ + dim_t i, j; \ + dim_t m_cur; \ + dim_t n_cur; \ + inc_t rstep_a; \ + inc_t cstep_b; \ + inc_t rstep_c, cstep_c; \ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + n_iter = n / NR; \ + n_left = n % NR; \ +\ + m_iter = m / MR; \ + m_left = m % MR; \ +\ + if ( n_left ) ++n_iter; \ + if ( m_left ) ++m_iter; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + rstep_a = ps_a; \ +\ + cstep_b = ps_b; \ +\ + rstep_c = rs_c * MR; \ + cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ \ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ +\ + /* Query the number of threads and thread ids for each loop. */ \ + dim_t jr_nt = bli_thread_n_way( thread ); \ + dim_t jr_tid = bli_thread_work_id( thread ); \ + dim_t ir_nt = bli_thread_n_way( caucus ); \ + dim_t ir_tid = bli_thread_work_id( caucus ); \ +\ + dim_t jr_start, jr_end; \ + dim_t ir_start, ir_end; \ + dim_t jr_inc, ir_inc; \ +\ + dim_t m_off_cblock, n_off_cblock; \ +\ + /* Determine the thread range and increment for the 2nd and 1st loops. + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + m_off_cblock = m_off + i * MR; \ + n_off_cblock = n_off + j * NR; \ +\ + if(!bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, m_cur, n_cur)) \ + { \ + if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, m_cur, n_cur)) \ + { \ + /* Handle interior and edge cases separately. */ \ + if ( m_cur == MR && n_cur == NR ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + } \ + else \ + { \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); \ +} + +INSERT_GENTFUNC_L_SD( gemmt, ker_var2 ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname, uplo, varname ) \ +\ +void PASTEMACT(ch,opname,uplo,varname) \ + ( \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m_off, \ + dim_t n_off, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_6); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + if(bli_gemmt_is_strictly_below_diag(m_off, n_off, m, n)) return; \ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool_t col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ + ctype* restrict b1; \ + ctype* restrict c1; \ +\ + dim_t m_iter, m_left; \ + dim_t n_iter, n_left; \ + dim_t i, j; \ + dim_t m_cur; \ + dim_t n_cur; \ + inc_t rstep_a; \ + inc_t cstep_b; \ + inc_t rstep_c, cstep_c; \ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + n_iter = n / NR; \ + n_left = n % NR; \ +\ + m_iter = m / MR; \ + m_left = m % MR; \ +\ + if ( n_left ) ++n_iter; \ + if ( m_left ) ++m_iter; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + rstep_a = ps_a; \ +\ + cstep_b = ps_b; \ +\ + rstep_c = rs_c * MR; \ + cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ \ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ +\ + /* Query the number of threads and thread ids for each loop. */ \ + dim_t jr_nt = bli_thread_n_way( thread ); \ + dim_t jr_tid = bli_thread_work_id( thread ); \ + dim_t ir_nt = bli_thread_n_way( caucus ); \ + dim_t ir_tid = bli_thread_work_id( caucus ); \ +\ + dim_t jr_start, jr_end; \ + dim_t ir_start, ir_end; \ + dim_t jr_inc, ir_inc; \ +\ + dim_t m_off_cblock, n_off_cblock; \ +\ + /* Determine the thread range and increment for the 2nd and 1st loops. + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_start; j < jr_end; j += jr_inc ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_start; i < ir_end; i += ir_inc ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + m_off_cblock = m_off + i * MR; \ + n_off_cblock = n_off + j * NR; \ +\ + if(!bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, m_cur, n_cur)) \ + { \ + if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, m_cur, n_cur)) \ + { \ + /* Handle interior and edge cases separately. */ \ + if ( m_cur == MR && n_cur == NR ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + } \ + else \ + { \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + } \ + } \ + } \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ +*/ \ +AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); \ +} + +INSERT_GENTFUNC_U_SD( gemmt, ker_var2 ) + diff --git a/frame/3/gemmt/bli_gemmt_var.h b/frame/3/gemmt/bli_gemmt_var.h new file mode 100644 index 000000000..00e258d32 --- /dev/null +++ b/frame/3/gemmt/bli_gemmt_var.h @@ -0,0 +1,81 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTEMAC0(opname) \ + ( \ + obj_t* a, \ + obj_t* b, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + cntl_t* cntl, \ + thrinfo_t* thread \ + ); + +GENPROT( gemmt_ker_var2 ) + + + +// +// Prototype BLAS-like interfaces with void pointer operands. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname, uplo, varname ) \ +\ +void PASTEMACT(ch,opname,uplo,varname) \ + ( \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m_off, \ + dim_t n_off, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +INSERT_GENTPROT_GEMMT_SD( gemmt, ker_var2 ) diff --git a/frame/compat/bla_gemmt.c b/frame/compat/bla_gemmt.c new file mode 100644 index 000000000..0923a08f7 --- /dev/null +++ b/frame/compat/bla_gemmt.c @@ -0,0 +1,235 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + uplo_t blis_uploc, \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + transb, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ +\ + if(!( n )) \ + return; \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_uploc, \ + blis_transa, \ + blis_transb, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + uplo_t blis_uploc; \ +\ + dim_t n0, k0; \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + uploc, \ + transa, \ + transb, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ + bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + bli_set_dims_with_trans( blis_transa, n0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, n0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ + bli_obj_set_uplo( blis_uploc, &co ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTFUNC_BLAS( gemmt, gemmt ) +#endif diff --git a/frame/compat/bla_gemmt.h b/frame/compat/bla_gemmt.h new file mode 100644 index 000000000..8043d6829 --- /dev/null +++ b/frame/compat/bla_gemmt.h @@ -0,0 +1,58 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype BLAS-to-BLIS interfaces. +// +#undef GENTPROT +#define GENTPROT( ftype, ch, blasname ) \ +\ +BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); + +#ifdef BLIS_ENABLE_BLAS +INSERT_GENTPROT_BLAS( gemmt ) +#endif diff --git a/frame/compat/bli_blas.h b/frame/compat/bli_blas.h index 24015074b..1ce976453 100644 --- a/frame/compat/bli_blas.h +++ b/frame/compat/bli_blas.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -185,6 +186,7 @@ #include "bla_syr2k.h" #include "bla_trmm.h" #include "bla_trsm.h" +#include "bla_gemmt.h" #include "bla_gemm_check.h" #include "bla_hemm_check.h" @@ -195,6 +197,7 @@ #include "bla_syr2k_check.h" #include "bla_trmm_check.h" #include "bla_trsm_check.h" +#include "bla_gemmt_check.h" // -- Fortran-compatible APIs to BLIS functions -- diff --git a/frame/compat/cblas/src/cblas.h b/frame/compat/cblas/src/cblas.h index 85778c8a4..cbf3e679a 100644 --- a/frame/compat/cblas/src/cblas.h +++ b/frame/compat/cblas/src/cblas.h @@ -448,6 +448,11 @@ void BLIS_EXPORT_BLAS cblas_strsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_DIAG Diag, f77_int M, f77_int N, float alpha, const float *A, f77_int lda, float *B, f77_int ldb); +void BLIS_EXPORT_BLAS cblas_sgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, float alpha, const float *A, + f77_int lda, const float *B, f77_int ldb, + float beta, float *C, f77_int ldc); void BLIS_EXPORT_BLAS cblas_dgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, @@ -478,6 +483,11 @@ void BLIS_EXPORT_BLAS cblas_dtrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_DIAG Diag, f77_int M, f77_int N, double alpha, const double *A, f77_int lda, double *B, f77_int ldb); +void BLIS_EXPORT_BLAS cblas_dgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, double alpha, const double *A, + f77_int lda, const double *B, f77_int ldb, + double beta, double *C, f77_int ldc); void BLIS_EXPORT_BLAS cblas_cgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, @@ -508,6 +518,11 @@ void BLIS_EXPORT_BLAS cblas_ctrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_DIAG Diag, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, void *B, f77_int ldb); +void BLIS_EXPORT_BLAS cblas_cgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc); void BLIS_EXPORT_BLAS cblas_zgemm(enum CBLAS_ORDER Order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, f77_int M, f77_int N, @@ -538,6 +553,11 @@ void BLIS_EXPORT_BLAS cblas_ztrsm(enum CBLAS_ORDER Order, enum CBLAS_SIDE Side, enum CBLAS_DIAG Diag, f77_int M, f77_int N, const void *alpha, const void *A, f77_int lda, void *B, f77_int ldb); +void BLIS_EXPORT_BLAS cblas_zgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc); /* diff --git a/frame/compat/cblas/src/cblas_cgemmt.c b/frame/compat/cblas/src/cblas_cgemmt.c new file mode 100644 index 000000000..c93696c67 --- /dev/null +++ b/frame/compat/cblas/src/cblas_cgemmt.c @@ -0,0 +1,130 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_cgemmt.c + * This program is a C interface to cgemmt. + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_cgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA,enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, + const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc) +{ + char TA, TB, UL; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB, F77_UL; +#else + #define F77_TA &TA + #define F77_TB &TB + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + if( Uplo == CblasUpper ) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_cgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_cgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_cgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemmt(F77_UL, F77_TA, F77_TB, &F77_N, &F77_K, (scomplex*)alpha, (scomplex*)A, + &F77_lda, (scomplex*)B, &F77_ldb, (scomplex*)beta, (scomplex*)C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if( Uplo == CblasUpper ) UL='U'; + else if( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_cgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_cgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_cgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_cgemmt(F77_UL, F77_TA, F77_TB, &F77_N, &F77_K, (scomplex*)alpha, (scomplex*)B, + &F77_ldb, (scomplex*)A, &F77_lda, (scomplex*)beta, (scomplex*)C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_cgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/cblas/src/cblas_dgemmt.c b/frame/compat/cblas/src/cblas_dgemmt.c new file mode 100644 index 000000000..eeb7ffce4 --- /dev/null +++ b/frame/compat/cblas/src/cblas_dgemmt.c @@ -0,0 +1,135 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_dgemmt.c + * This program is a C interface to dgemmt. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_dgemmt( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, + double alpha, const double *A, + f77_int lda, const double *B, f77_int ldb, + double beta, double *C, f77_int ldc) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + char TA, TB, UL; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB, F77_UL; +#else + #define F77_TA &TA + #define F77_TB &TB + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + if( Uplo == CblasUpper) UL = 'U'; + else if(Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_dgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_dgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_dgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + F77_UL = C2F_CHAR(&UL); + #endif + + F77_dgemmt(F77_UL,F77_TA, F77_TB, &F77_N, &F77_K, &alpha, A, + &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if(Uplo == CblasUpper) UL = 'L'; + else if(Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_dgemmt","Illegal Uplo setting, %d\n", Uplo); + } + + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_dgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_dgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + F77_UL = C2F_CHAR(&UL); + #endif + + F77_dgemmt(F77_UL,F77_TA, F77_TB, &F77_N, &F77_K, &alpha, B, + &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + } + else cblas_xerbla(1, "cblas_dgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; +} +#endif diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index fcdd946df..b118bc5af 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -6,6 +6,9 @@ * Merged cblas_f77.h and cblas_fortran_header.h * * (Heavily hacked down from the original) + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * */ #ifndef CBLAS_F77_H @@ -163,5 +166,9 @@ #define F77_zsyr2k zsyr2k_ #define F77_ztrmm ztrmm_ #define F77_ztrsm ztrsm_ +#define F77_dgemmt dgemmt_ +#define F77_sgemmt sgemmt_ +#define F77_cgemmt cgemmt_ +#define F77_zgemmt zgemmt_ #endif /* CBLAS_F77_H */ diff --git a/frame/compat/cblas/src/cblas_sgemmt.c b/frame/compat/cblas/src/cblas_sgemmt.c new file mode 100644 index 000000000..9f71143e6 --- /dev/null +++ b/frame/compat/cblas/src/cblas_sgemmt.c @@ -0,0 +1,135 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_sgemmt.c + * This program is a C interface to sgemmt. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_sgemmt( enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, + float alpha, const float *A, + f77_int lda, const float *B, f77_int ldb, + float beta, float *C, f77_int ldc) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + char TA, TB, UL; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB, F77_UL; +#else + #define F77_TA &TA + #define F77_TB &TB + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + if( Uplo == CblasUpper) UL = 'U'; + else if(Uplo == CblasLower) UL = 'L'; + else + { + cblas_xerbla(2, "cblas_sgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_sgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_sgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + F77_UL = C2F_CHAR(&UL); + #endif + + F77_sgemmt(F77_UL,F77_TA, F77_TB, &F77_N, &F77_K, &alpha, A, + &F77_lda, B, &F77_ldb, &beta, C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if(Uplo == CblasUpper) UL = 'L'; + else if(Uplo == CblasLower) UL = 'U'; + else + { + cblas_xerbla(2, "cblas_sgemmt","Illegal Uplo setting, %d\n", Uplo); + } + + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_sgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_sgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + F77_UL = C2F_CHAR(&UL); + #endif + + F77_sgemmt(F77_UL,F77_TA, F77_TB, &F77_N, &F77_K, &alpha, B, + &F77_ldb, A, &F77_lda, &beta, C, &F77_ldc); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + } + else cblas_xerbla(1, "cblas_sgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; +} +#endif diff --git a/frame/compat/cblas/src/cblas_zgemmt.c b/frame/compat/cblas/src/cblas_zgemmt.c new file mode 100644 index 000000000..3465a62ba --- /dev/null +++ b/frame/compat/cblas/src/cblas_zgemmt.c @@ -0,0 +1,131 @@ +#include "blis.h" +#ifdef BLIS_ENABLE_CBLAS +/* + * + * cblas_zgemmt.c + * This program is a C interface to zgemmt. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc. + * + */ + +#include "cblas.h" +#include "cblas_f77.h" +void cblas_zgemmt(enum CBLAS_ORDER Order, enum CBLAS_UPLO Uplo, + enum CBLAS_TRANSPOSE TransA,enum CBLAS_TRANSPOSE TransB, + f77_int N, f77_int K, + const void *alpha, const void *A, + f77_int lda, const void *B, f77_int ldb, + const void *beta, void *C, f77_int ldc) +{ + char TA, TB, UL; +#ifdef F77_CHAR + F77_CHAR F77_TA, F77_TB, F77_UL; +#else + #define F77_TA &TA + #define F77_TB &TB + #define F77_UL &UL +#endif + +#ifdef F77_INT + F77_INT F77_N=N, F77_K=K, F77_lda=lda, F77_ldb=ldb; + F77_INT F77_ldc=ldc; +#else + #define F77_N N + #define F77_K K + #define F77_lda lda + #define F77_ldb ldb + #define F77_ldc ldc +#endif + + extern int CBLAS_CallFromC; + extern int RowMajorStrg; + RowMajorStrg = 0; + CBLAS_CallFromC = 1; + + if( Order == CblasColMajor ) + { + if( Uplo == CblasUpper ) UL='U'; + else if ( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransA == CblasTrans) TA='T'; + else if ( TransA == CblasConjTrans ) TA='C'; + else if ( TransA == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(3, "cblas_zgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + if(TransB == CblasTrans) TB='T'; + else if ( TransB == CblasConjTrans ) TB='C'; + else if ( TransB == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(4, "cblas_zgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemmt(F77_UL, F77_TA, F77_TB, &F77_N, &F77_K, (dcomplex*)alpha, (dcomplex*)A, + &F77_lda, (dcomplex*)B, &F77_ldb, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); + } else if (Order == CblasRowMajor) + { + RowMajorStrg = 1; + if( Uplo == CblasUpper ) UL='U'; + else if( Uplo == CblasLower ) UL='L'; + else + { + cblas_xerbla(2, "cblas_zgemmt","Illegal Uplo setting, %d\n", Uplo); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransA == CblasTrans) TB='T'; + else if ( TransA == CblasConjTrans ) TB='C'; + else if ( TransA == CblasNoTrans ) TB='N'; + else + { + cblas_xerbla(3, "cblas_zgemmt","Illegal TransA setting, %d\n", TransA); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + if(TransB == CblasTrans) TA='T'; + else if ( TransB == CblasConjTrans ) TA='C'; + else if ( TransB == CblasNoTrans ) TA='N'; + else + { + cblas_xerbla(4, "cblas_zgemmt","Illegal TransB setting, %d\n", TransB); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; + } + #ifdef F77_CHAR + F77_TA = C2F_CHAR(&TA); + F77_TB = C2F_CHAR(&TB); + #endif + + F77_zgemmt(F77_UL, F77_TA, F77_TB, &F77_N, &F77_K, (dcomplex*)alpha, (dcomplex*)B, + &F77_ldb, (dcomplex*)A, &F77_lda, (dcomplex*)beta, (dcomplex*)C, &F77_ldc); + } + else cblas_xerbla(1, "cblas_zgemmt", "Illegal Order setting, %d\n", Order); + CBLAS_CallFromC = 0; + RowMajorStrg = 0; + return; +} +#endif diff --git a/frame/compat/check/bla_gemmt_check.h b/frame/compat/check/bla_gemmt_check.h new file mode 100644 index 000000000..d08ab5558 --- /dev/null +++ b/frame/compat/check/bla_gemmt_check.h @@ -0,0 +1,92 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef BLIS_ENABLE_BLAS + +#define bla_gemmt_check( dt_str, op_str, uploc, transa, transb, n, k, lda, ldb, ldc ) \ +{ \ + f77_int info = 0; \ + f77_int nota, notb; \ + f77_int conja, conjb; \ + f77_int ta, tb; \ + f77_int lower, upper; \ + f77_int nrowa, nrowb; \ +\ + nota = PASTEF770(lsame)( transa, "N", (ftnlen)1, (ftnlen)1 ); \ + notb = PASTEF770(lsame)( transb, "N", (ftnlen)1, (ftnlen)1 ); \ + conja = PASTEF770(lsame)( transa, "C", (ftnlen)1, (ftnlen)1 ); \ + conjb = PASTEF770(lsame)( transb, "C", (ftnlen)1, (ftnlen)1 ); \ + ta = PASTEF770(lsame)( transa, "T", (ftnlen)1, (ftnlen)1 ); \ + tb = PASTEF770(lsame)( transb, "T", (ftnlen)1, (ftnlen)1 ); \ +\ + lower = PASTEF770(lsame)( uploc, "L", (ftnlen)1, (ftnlen)1 ); \ + upper = PASTEF770(lsame)( uploc, "U", (ftnlen)1, (ftnlen)1 ); \ +\ + if ( nota ) { nrowa = *n; } \ + else { nrowa = *k; } \ + if ( notb ) { nrowb = *k; } \ + else { nrowb = *n; } \ +\ + if ( !lower && !upper ) \ + info = 1; \ + else if ( !nota && !conja && !ta ) \ + info = 2; \ + else if ( !notb && !conjb && !tb ) \ + info = 3; \ + else if ( *n < 0 ) \ + info = 4; \ + else if ( *k < 0 ) \ + info = 5; \ + else if ( *lda < bli_max( 1, nrowa ) ) \ + info = 8; \ + else if ( *ldb < bli_max( 1, nrowb ) ) \ + info = 10; \ + else if ( *ldc < bli_max( 1, *n ) ) \ + info = 13; \ +\ + if ( info != 0 ) \ + { \ + char func_str[ BLIS_MAX_BLAS_FUNC_STR_LENGTH ]; \ +\ + sprintf( func_str, "%s%-5s", dt_str, op_str ); \ +\ + bli_string_mkupper( func_str ); \ +\ + PASTEF770(xerbla)( func_str, &info, (ftnlen)6 ); \ +\ + return; \ + } \ +} + +#endif diff --git a/frame/include/bli_genarray_macro_defs.h b/frame/include/bli_genarray_macro_defs.h index 23cee1064..dc8e5e4c0 100644 --- a/frame/include/bli_genarray_macro_defs.h +++ b/frame/include/bli_genarray_macro_defs.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -103,6 +104,18 @@ arrayname[BLIS_NUM_FP_TYPES] = \ PASTEMAC(z,op) \ } + +#define GENARRAY_T(arrayname) \ +\ +arrayname[BLIS_NUM_FP_TYPES][2] = \ +{ \ + {PASTEMACT(s,gemmt,l,ker_var2), PASTEMACT(s,gemmt,u,ker_var2)}, \ + {NULL,NULL}, \ + {PASTEMACT(d,gemmt,l,ker_var2), PASTEMACT(d,gemmt,u,ker_var2)}, \ + {NULL,NULL}, \ +} + + #define GENARRAY_I(arrayname,op) \ \ arrayname[BLIS_NUM_FP_TYPES+1] = \ diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 39c5f12e3..34376c021 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -142,6 +142,16 @@ GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \ GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \ GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname ) +// --GEMMT specific kernels ---------------------------------------------------- +#define INSERT_GENTFUNC_L_SD( opname, funcname ) \ +\ +GENTFUNC(float, s, opname, l, funcname) \ +GENTFUNC(double, d, opname, l, funcname) + +#define INSERT_GENTFUNC_U_SD( opname, funcname ) \ +\ +GENTFUNC(float, s, opname, u, funcname) \ +GENTFUNC(double, d, opname, u, funcname) // -- Macros for functions with one operand ------------------------------------ @@ -158,6 +168,12 @@ GENTFUNC( scomplex, c, tfuncname ) \ GENTFUNC( dcomplex, z, tfuncname ) +#define INSERT_GENTFUNC_BASIC0_SD( tfuncname ) \ +\ +GENTFUNC( float, s, tfuncname ) \ +GENTFUNC( double, d, tfuncname ) + + #define INSERT_GENTFUNC_BASIC0_CZ( tfuncname ) \ \ GENTFUNC( scomplex, c, tfuncname ) \ diff --git a/frame/include/bli_gentprot_macro_defs.h b/frame/include/bli_gentprot_macro_defs.h index f6aa70946..83b7986e6 100644 --- a/frame/include/bli_gentprot_macro_defs.h +++ b/frame/include/bli_gentprot_macro_defs.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -121,7 +122,13 @@ GENTPROTSCAL( dcomplex, dcomplex, , z, blasname ) \ GENTPROTSCAL( float, scomplex, s, c, blasname ) \ GENTPROTSCAL( double, dcomplex, d, z, blasname ) - +// -- GEMMT specific function -------------------------------------------------- +#define INSERT_GENTPROT_GEMMT_SD(opname, funcname) \ +\ +GENTPROT( float, s, gemmt, l, funcname ) \ +GENTPROT( double, d, gemmt, l, funcname ) \ +GENTPROT( float, s, gemmt, u, funcname ) \ +GENTPROT( double, d, gemmt, u, funcname ) // -- Macros for functions with one operand ------------------------------------ @@ -138,6 +145,8 @@ GENTPROT( double, d, tfuncname ) \ GENTPROT( scomplex, c, tfuncname ) \ GENTPROT( dcomplex, z, tfuncname ) + + // -- (one auxiliary argument) -- #define INSERT_GENTPROT_BASIC( tfuncname, varname ) \ diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 907a5a26c..77bc14a37 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -155,6 +155,7 @@ #define MKSTR(s1) #s1 #define STRINGIFY_INT( s ) MKSTR( s ) +#define PASTEMACT(ch1, ch2, ch3, ch4) bli_ ## ch1 ## ch2 ## _ ## ch3 ## _ ## ch4 // Fortran-77 name-mangling macros. #define PASTEF770(name) name ## _ #define PASTEF77(ch1,name) ch1 ## name ## _ diff --git a/frame/include/bli_param_macro_defs.h b/frame/include/bli_param_macro_defs.h index cc1737e91..ec576c2cf 100644 --- a/frame/include/bli_param_macro_defs.h +++ b/frame/include/bli_param_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -581,7 +581,19 @@ static bool_t bli_has_nonunit_inc3( inc_t s1, inc_t s2, inc_t s3 ) ( s1 != 1 || s2 != 1 || s3 != 1 ); } +// offset-relate +static bool_t bli_gemmt_is_strictly_below_diag( dim_t m_off, dim_t n_off, dim_t m, dim_t n ) +{ + return ( bool_t ) + ( ( n_off + n - 1 ) < m_off ); +} + +static bool_t bli_gemmt_is_strictly_above_diag( dim_t m_off, dim_t n_off, dim_t m, dim_t n ) +{ + return ( bool_t ) + ( ( m_off + m - 1 ) < n_off ); +} // diag offset-related static void bli_negate_diag_offset( doff_t* diagoff ) diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 01476f185..28264c42a 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -926,11 +926,11 @@ typedef enum BLIS_TRMM3, BLIS_TRMM, BLIS_TRSM, - + BLIS_GEMMT, BLIS_NOID } opid_t; -#define BLIS_NUM_LEVEL3_OPS 10 +#define BLIS_NUM_LEVEL3_OPS 11 // -- Blocksize ID type -- diff --git a/frame/ind/bli_l3_ind.c b/frame/ind/bli_l3_ind.c index 02a7668b8..ce7ab2167 100644 --- a/frame/ind/bli_l3_ind.c +++ b/frame/ind/bli_l3_ind.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,21 +37,21 @@ static void_fp bli_l3_ind_oper_fp[BLIS_NUM_IND_METHODS][BLIS_NUM_LEVEL3_OPS] = { - /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm */ + /* gemm hemm herk her2k symm syrk, syr2k trmm3 trmm trsm gemmt*/ /* 3mh */ { bli_gemm3mh, bli_hemm3mh, bli_herk3mh, bli_her2k3mh, bli_symm3mh, - bli_syrk3mh, bli_syr2k3mh, bli_trmm33mh, NULL, NULL }, + bli_syrk3mh, bli_syr2k3mh, bli_trmm33mh, NULL, NULL , NULL }, /* 3m1 */ { bli_gemm3m1, bli_hemm3m1, bli_herk3m1, bli_her2k3m1, bli_symm3m1, - bli_syrk3m1, bli_syr2k3m1, bli_trmm33m1, bli_trmm3m1, bli_trsm3m1 }, + bli_syrk3m1, bli_syr2k3m1, bli_trmm33m1, bli_trmm3m1, bli_trsm3m1 , NULL }, /* 4mh */ { bli_gemm4mh, bli_hemm4mh, bli_herk4mh, bli_her2k4mh, bli_symm4mh, - bli_syrk4mh, bli_syr2k4mh, bli_trmm34mh, NULL, NULL }, + bli_syrk4mh, bli_syr2k4mh, bli_trmm34mh, NULL, NULL , NULL }, /* 4mb */ { bli_gemm4mb, NULL, NULL, NULL, NULL, - NULL, NULL, NULL, NULL, NULL }, + NULL, NULL, NULL, NULL, NULL , NULL }, /* 4m1 */ { bli_gemm4m1, bli_hemm4m1, bli_herk4m1, bli_her2k4m1, bli_symm4m1, - bli_syrk4m1, bli_syr2k4m1, bli_trmm34m1, bli_trmm4m1, bli_trsm4m1 }, + bli_syrk4m1, bli_syr2k4m1, bli_trmm34m1, bli_trmm4m1, bli_trsm4m1 , NULL }, /* 1m */ { bli_gemm1m, bli_hemm1m, bli_herk1m, bli_her2k1m, bli_symm1m, - bli_syrk1m, bli_syr2k1m, bli_trmm31m, bli_trmm1m, bli_trsm1m }, + bli_syrk1m, bli_syr2k1m, bli_trmm31m, bli_trmm1m, bli_trsm1m , NULL }, /* nat */ { bli_gemmnat, bli_hemmnat, bli_herknat, bli_her2knat, bli_symmnat, - bli_syrknat, bli_syr2knat, bli_trmm3nat, bli_trmmnat, bli_trsmnat }, + bli_syrknat, bli_syr2knat, bli_trmm3nat, bli_trmmnat, bli_trsmnat , bli_gemmtnat }, }; // @@ -99,6 +99,7 @@ bool_t PASTEMAC(opname,ind_has_avail)( num_t dt ) */ GENFUNC( gemm, BLIS_GEMM ) +GENFUNC( gemmt, BLIS_GEMMT ) GENFUNC( hemm, BLIS_HEMM ) GENFUNC( herk, BLIS_HERK ) GENFUNC( her2k, BLIS_HER2K ) diff --git a/frame/ind/bli_l3_ind.h b/frame/ind/bli_l3_ind.h index 216eaedcb..ae57bdfb9 100644 --- a/frame/ind/bli_l3_ind.h +++ b/frame/ind/bli_l3_ind.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -44,6 +45,7 @@ void_fp PASTEMAC(opname,ind_get_avail)( num_t dt ); /*bool_t PASTEMAC(opname,ind_has_avail)( num_t dt ); */ GENPROT( gemm ) +GENPROT( gemmt ) GENPROT( hemm ) GENPROT( herk ) GENPROT( her2k ) diff --git a/frame/ind/oapi/bli_l3_ind_oapi.c b/frame/ind/oapi/bli_l3_ind_oapi.c index 3bdcd96ee..931153a2d 100644 --- a/frame/ind/oapi/bli_l3_ind_oapi.c +++ b/frame/ind/oapi/bli_l3_ind_oapi.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -67,6 +67,7 @@ void PASTEMAC(opname,imeth) \ } GENFRONT( gemm, ind ) +GENFRONT( gemmt, ind ) GENFRONT( her2k, ind ) GENFRONT( syr2k, ind ) diff --git a/frame/ind/oapi/bli_l3_ind_oapi.h b/frame/ind/oapi/bli_l3_ind_oapi.h index d4767925d..2d9dc5e10 100644 --- a/frame/ind/oapi/bli_l3_ind_oapi.h +++ b/frame/ind/oapi/bli_l3_ind_oapi.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +50,8 @@ BLIS_EXPORT_BLIS void PASTEMAC(syrk,imeth) ( obj_t* alpha, obj_t* a BLIS_EXPORT_BLIS void PASTEMAC(syr2k,imeth)( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ BLIS_EXPORT_BLIS void PASTEMAC(trmm3,imeth)( side_t side, obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ BLIS_EXPORT_BLIS void PASTEMAC(trmm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ -BLIS_EXPORT_BLIS void PASTEMAC(trsm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); +BLIS_EXPORT_BLIS void PASTEMAC(trsm,imeth) ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, rntm_t* rntm ); \ +BLIS_EXPORT_BLIS void PASTEMAC(gemmt,imeth) ( obj_t* alpha, obj_t* a, obj_t* b, obj_t* beta, obj_t* c, cntx_t* cntx, rntm_t* rntm ); \ GENPROT( nat ) GENPROT( ind ) diff --git a/frame/ind/oapi/bli_l3_nat_oapi.c b/frame/ind/oapi/bli_l3_nat_oapi.c index dc170d7ce..a5ee9bab3 100644 --- a/frame/ind/oapi/bli_l3_nat_oapi.c +++ b/frame/ind/oapi/bli_l3_nat_oapi.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -81,6 +81,7 @@ void PASTEMAC(opname,imeth) \ // defined in the sandbox environment. #ifndef BLIS_ENABLE_SANDBOX GENFRONT( gemm, gemm, nat ) +GENFRONT( gemmt, gemm, nat ) #endif GENFRONT( her2k, gemm, nat ) GENFRONT( syr2k, gemm, nat ) diff --git a/test/Makefile b/test/Makefile index fec2c1458..73afeb98a 100644 --- a/test/Makefile +++ b/test/Makefile @@ -180,6 +180,7 @@ blis: \ test_scalv_blis.x \ \ test_gemm_blis.x \ + test_gemmt_blis.x \ test_hemm_blis.x \ test_herk_blis.x \ test_her2k_blis.x \ @@ -242,6 +243,7 @@ mkl: test_dotv_mkl.x \ test_scalv_mkl.x \ \ test_gemm_mkl.x \ + test_gemmt_mkl.x \ test_hemm_mkl.x \ test_herk_mkl.x \ test_her2k_mkl.x \ diff --git a/test/test_gemmt.c b/test/test_gemmt.c new file mode 100644 index 000000000..f74401e72 --- /dev/null +++ b/test/test_gemmt.c @@ -0,0 +1,529 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2020, Advanced Micro Devices, Inc. + + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include "blis.h" +#include "cblas.h" + +//#define FILE_IN_OUT +//#define CBLAS +//#define PRINT +#define MATRIX_INITIALISATION +int main( int argc, char** argv ) +{ + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t n, k; + num_t dt; + int r, n_repeats; + trans_t transa; + trans_t transb; + uplo_t uploc; +#ifndef FILE_IN_OUT + dim_t p; + dim_t p_begin, p_end, p_inc; + int n_input, k_input; +#endif + + double dtime; + double dtime_save; + double gflops; +#ifdef FILE_IN_OUT + FILE* fin = NULL; + FILE* fout = NULL; + +#endif + //bli_init(); + + //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 3; + +#ifndef FILE_IN_OUT +#ifndef PRINT + p_begin = 48; + p_end = 10000; + p_inc = 192; + + n_input = -1; + k_input = -1; +#else + p_begin = 16; + p_end = 16; + p_inc = 1; + + k_input = 50; + n_input = 50; +#endif +#endif +#if 1 + //dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; +#else + //dt = BLIS_SCOMPLEX; + dt = BLIS_DCOMPLEX; +#endif + + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + + uploc = BLIS_UPPER; + +#ifdef FILE_IN_OUT + if (argc < 3) + { + printf("Usage: ./test_gemmt_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if (fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + fprintf(fout, "n\t k\t lda\t ldb\t ldc\t gflops\n"); + + + printf("~~~~~~~~~~_BLAS\t n\t k\t lda\t ldb\t ldc \t gflops\n"); + + inc_t cs_a; + inc_t cs_b; + inc_t cs_c; + + while (fscanf(fin, "%ld %ld %ld %ld %ld\n", &k, &n, &cs_a, &cs_b, &cs_c) == 5) + { + if ((n > cs_a) || (k > cs_b) || (n > cs_c)) continue; // leading dimension should be greater than number of rows + + bli_obj_create( dt, 1, 1, 0, 0, &alpha); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); + + bli_obj_create( dt, n, k, 1, cs_a, &a ); + bli_obj_create( dt, k, n, 1, cs_b, &b ); + bli_obj_create( dt, n, n, 1, cs_c, &c ); + bli_obj_create( dt, n, n, 1, cs_c, &c_save ); +#ifdef MATRIX_INITIALISATION + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); +#endif + bli_obj_set_struc( BLIS_TRIANGULAR, &c ); + bli_obj_set_uplo( uploc, &c ); + + bli_obj_set_conjtrans( transa, &a); + bli_obj_set_conjtrans( transb, &b); + + //Randomize C and zero the unstored triangle to ensure the + //implementation reads only from the stored region. + bli_randm( &c ); + bli_mktrim( &c ); + + //bli_setsc( 0.0, -1, &alpha ); + //bli_setsc( 0.0, 1, &beta ); + + bli_setsc( 1, 0.0, &alpha ); + bli_setsc( 1, 0.0, &beta ); + +#else + for ( p = p_begin; p <= p_end; p += p_inc ) + { + if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); + else n = ( dim_t ) n_input; + if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); + else k = ( dim_t ) k_input; + + bli_obj_create( dt, 1, 1, 0, 0, &alpha ); + bli_obj_create( dt, 1, 1, 0, 0, &beta ); +#ifdef CBLAS + bli_obj_create( dt, n, k, k, 1, &a ); + bli_obj_create( dt, k, n, n, 1, &b ); + bli_obj_create( dt, n, n, n, 1, &c ); + bli_obj_create( dt, n, n, n, 1, &c_save ); +#else + bli_obj_create( dt, n, k, 1, n, &a ); + bli_obj_create( dt, k, n, 1, k, &b ); + bli_obj_create( dt, n, n, 1, n, &c ); + bli_obj_create( dt, n, n, 1, n, &c_save ); + +#endif + + bli_randm( &a ); + bli_randm( &b ); + bli_randm( &c ); + bli_obj_set_struc( BLIS_TRIANGULAR, &c ); + bli_obj_set_uplo( uploc, &c ); + + bli_obj_set_conjtrans( transa, &a ); + bli_obj_set_conjtrans( transb, &b ); + + //Randomize C and zero the unstored triangle to ensure the + //implementation reads only from the stored region. + + bli_randm( &c ); + bli_mktrim( &c ); + + bli_setsc( (0.9/1.0), 0.2, &alpha ); + bli_setsc( -(1.1/1.0), 0.3, &beta ); + +#endif + bli_copym( &c, &c_save ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + bli_copym( &c_save, &c ); + + + dtime = bli_clock(); + + +#ifdef PRINT + bli_printm( "a", &a, "%4.1f", "," ); + bli_printm( "b", &b, "%4.1f", "," ); + bli_printm( "c", &c, "%4.1f", "," ); +#endif + +#ifdef BLIS + bli_gemmt( &alpha, + &a, + &b, + &beta, + &c ); + +#else + +#ifdef CBLAS + enum CBLAS_ORDER cblas_order; + enum CBLAS_UPLO cblas_uplo; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + if ( bli_obj_row_stride( &c ) == 1 ) + cblas_order = CblasColMajor; + else + cblas_order = CblasRowMajor; + if( bli_is_upper( uploc ) ) + cblas_uplo = CblasUpper; + else + cblas_uplo = CblasLower; + + if( bli_is_trans( transa ) ) + cblas_transa = CblasTrans; + else if( bli_is_conjtrans( transa ) ) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if( bli_is_trans( transb ) ) + cblas_transb = CblasTrans; + else if( bli_is_conjtrans( transb ) ) + cblas_transb = CblasConjTrans; + else + cblas_transb = CblasNoTrans; +#else + + f77_char f77_transa; + f77_char f77_transb; + f77_char f77_uploc; + + bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + bli_param_map_blis_to_netlib_uplo( uploc, &f77_uploc ); +#endif + + if ( bli_is_float( dt ) ) + { +#ifdef CBLAS + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + cblas_sgemmt( cblas_order, + cblas_uplo, + cblas_transa, + cblas_transb, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc ); + +#else + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + float* alphap = bli_obj_buffer( &alpha ); + float* ap = bli_obj_buffer( &a ); + float* bp = bli_obj_buffer( &b ); + float* betap = bli_obj_buffer( &beta ); + float* cp = bli_obj_buffer( &c ); + + sgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + +#endif + } + else if ( bli_is_double( dt ) ) + { +#ifdef CBLAS + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + cblas_dgemmt( cblas_order, + cblas_uplo, + cblas_transa, + cblas_transb, + nn, + kk, + *alphap, + ap,lda, + bp, ldb, + *betap, + cp, ldc + ); +#else + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + double* alphap = bli_obj_buffer( &alpha ); + double* ap = bli_obj_buffer( &a ); + double* bp = bli_obj_buffer( &b ); + double* betap = bli_obj_buffer( &beta ); + double* cp = bli_obj_buffer( &c ); + + dgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { +#ifdef CBLAS + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cblas_cgemmt( cblas_order, + cblas_uplo, + cblas_transa, + cblas_transb, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + +#else + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* ap = bli_obj_buffer( &a ); + scomplex* bp = bli_obj_buffer( &b ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* cp = bli_obj_buffer( &c ); + + cgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + +#endif + } + else if ( bli_is_dcomplex( dt ) ) + { +#ifdef CBLAS + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_row_stride( &a ); + f77_int ldb = bli_obj_row_stride( &b ); + f77_int ldc = bli_obj_row_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + cblas_zgemmt( cblas_order, + cblas_uplo, + cblas_transa, + cblas_transb, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc ); + +#else + + f77_int kk = bli_obj_width_after_trans( &a ); + f77_int nn = bli_obj_width( &c ); + f77_int lda = bli_obj_col_stride( &a ); + f77_int ldb = bli_obj_col_stride( &b ); + f77_int ldc = bli_obj_col_stride( &c ); + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* ap = bli_obj_buffer( &a ); + dcomplex* bp = bli_obj_buffer( &b ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* cp = bli_obj_buffer( &c ); + + zgemmt_( &f77_uploc, + &f77_transa, + &f77_transb, + &nn, + &kk, + alphap, + ap, &lda, + bp, &ldb, + betap, + cp, &ldc ); + +#endif + } +#endif + +#ifdef PRINT + bli_printm( "c after", &c, "%4.1f", "" ); + exit(1); +#endif + + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = ( n * k * n ) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 4.0; + +#ifdef BLIS + printf( "data_gemmt_blis" ); +#else + printf( "data_gemmt_%s", BLAS ); +#endif + + +#ifdef FILE_IN_OUT + printf("%4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ + ( unsigned long )n, + ( unsigned long )k, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops ); + + + fprintf(fout, "%4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ + ( unsigned long )n, + ( unsigned long )k, (unsigned long)cs_a, (unsigned long)cs_b, (unsigned long)cs_c, gflops ); + fflush(fout); + +#else + printf( "( %2lu, 1:4 ) = [ %4lu %4lu %7.2f ];\n", + ( unsigned long )(p - p_begin)/p_inc + 1, + ( unsigned long )n, + ( unsigned long )k, gflops ); +#endif + bli_obj_free( &alpha ); + bli_obj_free( &beta ); + + bli_obj_free( &a ); + bli_obj_free( &b ); + bli_obj_free( &c ); + bli_obj_free( &c_save ); + } + + //bli_finalize(); +#ifdef FILE_IN_OUT + fclose(fin); + fclose(fout); +#endif + return 0; +}