Added a new field in cntx to store l3 threshold function pointers

Details:
- Adding threshold function pointers to cntx gives flexibility to choose
  different threshold functions for different configurations.
- In case of fat binary where configuration is decided at run-time,
  adding threshold functions under a macro enables these functions for
  all the configs under a family. This can be avoided by adding function
  pointers to cntx which can be queried from cntx during run-time
  based on the config chosen.

Change-Id: Iaf7e69e45ae5bb60e4d0f75c7542a91e1609773f
This commit is contained in:
Meghana Vankadari
2021-07-30 23:47:06 +05:30
parent 0cb552c8f8
commit 6bad157754
13 changed files with 319 additions and 97 deletions

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -66,6 +66,17 @@ void bli_cntx_init_zen( cntx_t* cntx )
cntx
);
// Update the context with architecture specific threshold functions
bli_cntx_set_l3_thresh_funcs
(
2,
// GEMMT
BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen,
// SYRK
BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen,
cntx
);
// Update the context with optimized level-1f kernels.
bli_cntx_set_l1f_kers
(

View File

@@ -63,6 +63,17 @@ void bli_cntx_init_zen2( cntx_t* cntx )
cntx
);
// Update the context with architecture specific threshold functions
bli_cntx_set_l3_thresh_funcs
(
2,
//gemmt
BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen,
//SYRK
BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen,
cntx
);
// Update the context with optimized packm kernels.
bli_cntx_set_packm_kers
(

View File

@@ -63,6 +63,17 @@ void bli_cntx_init_zen3( cntx_t* cntx )
cntx
);
// Update the context with architecture specific threshold functions
bli_cntx_set_l3_thresh_funcs
(
2,
// GEMMT
BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen,
// SYRK
BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen,
cntx
);
// packm kernels
bli_cntx_set_packm_kers
(

View File

@@ -5,6 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -100,3 +101,14 @@ INSERT_GENTDEF( trsm )
#endif
// Signature of the level-3 threshold predicates whose addresses are kept
// in the cntx of each configuration. A predicate is handed the operand
// objects and the context and, based on the input dimensions, decides
// between the SUP and the native implementation.
typedef bool (*thresh_func_ft)( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx );

View File

@@ -71,8 +71,6 @@ err_t bli_gemmsup
return BLIS_FAILURE;
}
const dim_t m = bli_obj_length( c );
const dim_t n = bli_obj_width( c );
trans_t transa = bli_obj_conjtrans_status( a );
trans_t transb = bli_obj_conjtrans_status( b );
@@ -104,30 +102,14 @@ err_t bli_gemmsup
// that function assumes the context pointer is valid.
if ( cntx == NULL ) cntx = bli_gks_query_cntx();
// Return early if a microkernel preference-induced transposition would
// have been performed and shifted the dimensions outside of the space
// of sup-handled problems.
if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( c, BLIS_GEMM_UKR, cntx ) )
{
const num_t dt = bli_obj_dt( c );
const dim_t k = bli_obj_width_after_trans( a );
thresh_func_ft func_fp;
// Pass in m and n reversed, which simulates a transposition of the
// entire operation pursuant to the microkernel storage preference.
if ( !bli_cntx_l3_sup_thresh_is_met( dt, n, m, k, cntx ) ) {
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Trasposition results in sizes beyond SUP thresholds.");
return BLIS_FAILURE;
}
}
else // ukr_prefers_storage_of( c, ... )
{
const num_t dt = bli_obj_dt( c );
const dim_t k = bli_obj_width_after_trans( a );
func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx);
if ( !bli_cntx_l3_sup_thresh_is_met( dt, m, n, k, cntx ) ) {
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds.");
return BLIS_FAILURE;
}
// Return early if the sizes are beyond SUP thresholds
if ( !func_fp( a, b, c, cntx ) ) {
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds.");
return BLIS_FAILURE;
}
// Initialize a local runtime with global settings if necessary. Note
@@ -215,7 +197,6 @@ err_t bli_gemmtsup
return BLIS_FAILURE;
}
const dim_t n = bli_obj_width( c );
trans_t transa = bli_obj_conjtrans_status( a );
trans_t transb = bli_obj_conjtrans_status( b );
@@ -247,10 +228,11 @@ err_t bli_gemmtsup
// that function assumes the context pointer is valid.
if ( cntx == NULL ) cntx = bli_gks_query_cntx();
num_t dt = bli_obj_dt(c);
dim_t k = bli_obj_width_after_trans( a );
thresh_func_ft func_fp;
if ( !bli_cntx_gemmtsup_thresh_is_met( dt, n, k, cntx ) )
func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMMT, cntx);
if ( !func_fp( a, b, c, cntx ) )
{
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes beyond SUP thresholds.");
return BLIS_FAILURE;
@@ -343,7 +325,6 @@ err_t bli_syrksup
return BLIS_FAILURE;
}
const dim_t n = bli_obj_width( c );
trans_t transa = bli_obj_conjtrans_status( a );
//Don't use sup for currently unsupported storage types in cgemmsup
@@ -372,10 +353,8 @@ err_t bli_syrksup
// that function assumes the context pointer is valid.
if ( cntx == NULL ) cntx = bli_gks_query_cntx();
num_t dt = bli_obj_dt( c );
dim_t k = bli_obj_width_after_trans( a );
if( !bli_cntx_syrksup_thresh_is_met( dt, n, k, stor_id, cntx))
thresh_func_ft func_fp = bli_cntx_get_l3_thresh_func(BLIS_SYRK, cntx);
if( !func_fp( a, &at_local, c, cntx))
{
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - sizes beyond SUP thresholds.");
return BLIS_FAILURE;

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -1605,6 +1605,105 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... )
// -----------------------------------------------------------------------------
void bli_cntx_set_l3_thresh_funcs( dim_t n_funcs, ... )
{
	// This function can be called from the bli_cntx_init_*() function for
	// a particular architecture if the kernel developer wishes to use
	// non-default level-3 threshold functions. It should be called after
	// bli_cntx_init_defaults() so that the context begins with default
	// functions across all operations.
	//
	// Arguments:
	// - n_funcs: number of (operation id, function pointer) pairs that
	//   follow.
	// - ...:     n_funcs pairs of ( opid_t, void_fp ), followed by a final
	//   cntx_t* naming the context to update.

	/* Example prototypes:

	   void bli_cntx_set_l3_thresh_funcs
	   (
	     dim_t   n_funcs,
	     opid_t  op0_id, void_fp func0_fp,
	     opid_t  op1_id, void_fp func1_fp,
	     opid_t  op2_id, void_fp func2_fp,
	     ...
	     cntx_t* cntx
	   );
	*/

	va_list args;
	dim_t   i;

	// Allocate some temporary local arrays. They are needed because the
	// context pointer is the LAST variadic argument, so the pairs must be
	// buffered until it has been read.

	#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_cntx_set_l3_thresh_funcs(): " );
	#endif
	// The ids are operation ids, so the array is typed (and sized) as
	// opid_t.
	opid_t* func_ids = bli_malloc_intl( n_funcs * sizeof( opid_t ) );

	#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_cntx_set_l3_thresh_funcs(): " );
	#endif
	void_fp* func_fps = bli_malloc_intl( n_funcs * sizeof( void_fp ) );

	// -- Begin variable argument section --

	// Initialize variable argument environment.
	va_start( args, n_funcs );

	// Process n_funcs tuples.
	for ( i = 0; i < n_funcs; ++i )
	{
		// Here, we query the variable argument list for:
		// - the opid_t of the operation we're about to process,
		// - the threshold function pointer
		// that we need to store to the context.
		const opid_t op_id   = ( opid_t  )va_arg( args, opid_t  );
		void_fp      func_fp = ( void_fp )va_arg( args, void_fp );

		// Store the values in our temporary arrays.
		func_ids[ i ] = op_id;
		func_fps[ i ] = func_fp;
	}

	// The last argument should be the context pointer.
	cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* );

	// Shutdown variable argument environment and clean up stack.
	va_end( args );

	// -- End variable argument section --

	// Query the context for the address of:
	// - the level-3 threshold func array
	// The accessor returns void**, so the local uses the same type.
	void** cntx_l3_thresh_funcs = bli_cntx_l3_thresh_funcs_buf( cntx );

	// Now that we have the context address, we want to copy the values
	// from the temporary buffers into the corresponding slots in the
	// context.

	// Process each (operation id, function pointer) tuple provided.
	for ( i = 0; i < n_funcs; ++i )
	{
		// Read the current operation id and function pointer.
		const opid_t func_id = func_ids[ i ];
		void_fp      func_fp = func_fps[ i ];

		// Store the function pointer in the cntx slot for this operation.
		// The explicit cast matches the void* element type of the buffer.
		cntx_l3_thresh_funcs[ func_id ] = ( void* )func_fp;
	}

	// Free the temporary local arrays.

	#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_cntx_set_l3_thresh_funcs(): " );
	#endif
	bli_free_intl( func_ids );

	#ifdef BLIS_ENABLE_MEM_TRACING
	printf( "bli_cntx_set_l3_thresh_funcs(): " );
	#endif
	bli_free_intl( func_fps );
}
// -----------------------------------------------------------------------------
void bli_cntx_set_packm_kers( dim_t n_kers, ... )
{
// This function can be called from the bli_cntx_init_*() function for

View File

@@ -128,6 +128,10 @@ BLIS_INLINE func_t* bli_cntx_l1v_kers_buf( cntx_t* cntx )
{
return cntx->l1v_kers;
}
// Returns the address of the context's array of level-3 threshold
// function pointers.
BLIS_INLINE void** bli_cntx_l3_thresh_funcs_buf(cntx_t* cntx )
{
	void** thresh_funcs = cntx->l3_thresh_funcs;

	return thresh_funcs;
}
BLIS_INLINE func_t* bli_cntx_packm_kers_buf( cntx_t* cntx )
{
return cntx->packm_kers;
@@ -307,65 +311,6 @@ BLIS_INLINE dim_t bli_cntx_get_l3_sup_thresh_dt( num_t dt, threshid_t thresh_id,
return thresh_dt;
}
// Returns TRUE as soon as one of m, n, or k (checked in that order) is
// smaller than its SUP threshold (BLIS_MT, BLIS_NT, BLIS_KT respectively)
// for datatype dt; returns FALSE when none of them is.
BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( num_t dt, dim_t m, dim_t n, dim_t k, cntx_t* cntx )
{
	// The || chain short-circuits, so the thresholds are queried in the
	// order m, n, k and only as far as needed.
	const bool below_some_thresh =
	    ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) ||
	    ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) ||
	    ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) );

	return below_some_thresh ? TRUE : FALSE;
}
// -- gemmt specific function
// Decides whether a gemmt problem of size n (with inner dimension k) and
// datatype dt meets the SUP thresholds.
BLIS_INLINE bool bli_cntx_gemmtsup_thresh_is_met( num_t dt, dim_t n, dim_t k, cntx_t* cntx )
{
#ifdef BLIS_CONFIG_EPYC
	// Under BLIS_CONFIG_EPYC, double and dcomplex use fixed cutoffs.
	if ( bli_is_double( dt ) )
	{
		// n < 300 is tested first, so k / n is only evaluated for n >= 300.
		return ( ( n < 300 ) || ( ( k / n ) > 50 ) ) ? TRUE : FALSE;
	}

	if ( bli_is_dcomplex( dt ) )
	{
		return ( n < 100 ) ? TRUE : FALSE;
	}
#endif

	// Every other case defers to the generic check, with n used for both
	// the m and n dimensions.
	return bli_cntx_l3_sup_thresh_is_met( dt, n, n, k, cntx );
}
// -- syrk specific function
// Decides whether a syrk problem of size n (with inner dimension k),
// datatype dt and storage combination stor_id meets the SUP thresholds.
BLIS_INLINE bool bli_cntx_syrksup_thresh_is_met( num_t dt, dim_t n, dim_t k, stor3_t stor_id, cntx_t* cntx )
{
#ifdef BLIS_CONFIG_EPYC
	// Under BLIS_CONFIG_EPYC, double uses fixed cutoffs that depend on the
	// storage combination.
	if ( bli_is_double( dt ) )
	{
		const bool is_rrc_or_ccr = ( stor_id == BLIS_RRC ) ||
		                           ( stor_id == BLIS_CCR );

		if ( !is_rrc_or_ccr )
		{
			return ( n < 150 ) ? TRUE : FALSE;
		}

		if ( n < 140 )                     return TRUE;
		if ( ( n < 200 ) && ( k < 100 ) )  return TRUE;
		if ( ( n <= 450 ) && ( k < 50 ) )  return TRUE;

		return FALSE;
	}
#endif

	// The gemm thresholds are reused here temporarily (n stands in for both
	// m and n); syrk-specific thresholds still need to be derived.
	return bli_cntx_l3_sup_thresh_is_met( dt, n, n, k, cntx );
}
// -----------------------------------------------------------------------------
BLIS_INLINE void* bli_cntx_get_l3_sup_handler( opid_t op, cntx_t* cntx )
@@ -520,6 +465,14 @@ BLIS_INLINE void_fp bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t*
return fp;
}
// Looks up the threshold function stored in the context for the
// operation identified by func_id and returns it as an untyped pointer.
BLIS_INLINE void* bli_cntx_get_l3_thresh_func( opid_t func_id, cntx_t* cntx )
{
	return bli_cntx_l3_thresh_funcs_buf( cntx )[ func_id ];
}
BLIS_INLINE func_t* bli_cntx_get_unpackm_kers( l1mkr_t ker_id, cntx_t* cntx )
{
func_t* func = NULL;
@@ -599,7 +552,6 @@ BLIS_INLINE bool bli_cntx_l3_nat_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk
}
// -----------------------------------------------------------------------------
BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx )
{
// For induced methods, return the ukernel storage preferences of the
@@ -650,6 +602,34 @@ BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk
}
// -----------------------------------------------------------------------------
// Object-based SUP threshold check. The datatype and the m/n dimensions
// are read from c, and k is the width of a after transposition. Returns
// TRUE as soon as one of m, n, or k is below its SUP threshold, FALSE
// otherwise. The b operand is not inspected.
BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx )
{
	const num_t dt = bli_obj_dt( c );
	const dim_t k  = bli_obj_width_after_trans( a );

	// If the gemm microkernel dislikes the storage of c, m and n are
	// swapped before the thresholds are checked.
	const bool swap_mn = bli_cntx_l3_vir_ukr_dislikes_storage_of( c, BLIS_GEMM_UKR, cntx );

	const dim_t m = swap_mn ? bli_obj_width( c )  : bli_obj_length( c );
	const dim_t n = swap_mn ? bli_obj_length( c ) : bli_obj_width( c );

	// Short-circuit evaluation keeps the query order m, n, k.
	const bool below_some_thresh =
	    ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) ||
	    ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) ||
	    ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) );

	return below_some_thresh ? TRUE : FALSE;
}
// -----------------------------------------------------------------------------
BLIS_INLINE bool bli_cntx_l3_sup_ker_prefers_rows_dt( num_t dt, stor3_t stor_id, cntx_t* cntx )
{
@@ -814,6 +794,7 @@ BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... );
BLIS_EXPORT_BLIS void bli_cntx_set_l1f_kers( dim_t n_kers, ... );
BLIS_EXPORT_BLIS void bli_cntx_set_l1v_kers( dim_t n_kers, ... );
BLIS_EXPORT_BLIS void bli_cntx_set_packm_kers( dim_t n_kers, ... );
BLIS_EXPORT_BLIS void bli_cntx_set_l3_thresh_funcs( dim_t n_funcs, ... );
BLIS_EXPORT_BLIS void bli_cntx_print( cntx_t* cntx );

View File

@@ -1432,6 +1432,7 @@ typedef struct cntx_s
func_t l3_vir_ukrs[ BLIS_NUM_LEVEL3_UKRS ];
func_t l3_nat_ukrs[ BLIS_NUM_LEVEL3_UKRS ];
mbool_t l3_nat_ukrs_prefs[ BLIS_NUM_LEVEL3_UKRS ];
void* l3_thresh_funcs[ BLIS_NUM_LEVEL3_OPS ];
blksz_t l3_sup_thresh[ BLIS_NUM_THRESH ];
void* l3_sup_handlers[ BLIS_NUM_LEVEL3_OPS ];

View File

@@ -1,7 +1,7 @@
##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.##
set(SUBDIRECTORIES "1" "1f" "1m" "2" "3")
set(SUBDIRECTORIES "1" "1f" "1m" "2" "3" "util")
#Add all subdirectories
foreach(VAR ${SUBDIRECTORIES})

View File

@@ -273,3 +273,19 @@ void bli_dgemm_ref_k1_nn
cntl_t* cntl
);
// threshold functions
// Both prototypes below take the operand objects a, b, c and the context,
// and return a bool telling the caller whether the SUP thresholds are met.
// gemmt threshold function for the zen configurations.
bool bli_cntx_gemmtsup_thresh_is_met_zen
(
obj_t* a,
obj_t* b,
obj_t* c,
cntx_t* cntx
);
// syrk threshold function for the zen configurations.
bool bli_cntx_syrksup_thresh_is_met_zen
(
obj_t* a,
obj_t* b,
obj_t* c,
cntx_t* cntx
);

View File

@@ -0,0 +1,6 @@
##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.##
# Add the source file in this directory (bli_thresh_funcs_zen.c) to the
# project's target as a private source.
target_sources("${PROJECT_NAME}"
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/bli_thresh_funcs_zen.c
)

View File

@@ -0,0 +1,88 @@
/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
// -- gemmt specific function
// Threshold check for gemmt. The datatype and n (the length of c) are
// read from c, and k is the width of a after transposition. double and
// dcomplex use fixed cutoffs; every other datatype defers to
// bli_cntx_l3_sup_thresh_is_met().
bool bli_cntx_gemmtsup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx )
{
	const num_t dt = bli_obj_dt( c );
	const dim_t n  = bli_obj_length( c );
	const dim_t k  = bli_obj_width_after_trans( a );

	if ( bli_is_double( dt ) )
	{
		// n < 300 is tested first, so k / n is only evaluated for n >= 300.
		return ( ( n < 300 ) || ( ( k / n ) > 50 ) ) ? TRUE : FALSE;
	}

	if ( bli_is_dcomplex( dt ) )
	{
		return ( n < 100 ) ? TRUE : FALSE;
	}

	return bli_cntx_l3_sup_thresh_is_met( a, b, c, cntx );
}
// -- syrk specific function
// Threshold check for syrk. The datatype and n (the length of c) are read
// from c, k is the width of a after transposition, and the storage
// combination is derived from the strides of c, a and b. double uses
// fixed cutoffs that depend on the storage combination; every other
// datatype defers to bli_cntx_l3_sup_thresh_is_met().
bool bli_cntx_syrksup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx )
{
	const num_t   dt      = bli_obj_dt( c );
	const dim_t   n       = bli_obj_length( c );
	const dim_t   k       = bli_obj_width_after_trans( a );
	const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b );

	if ( !bli_is_double( dt ) )
	{
		return bli_cntx_l3_sup_thresh_is_met( a, b, c, cntx );
	}

	const bool is_rrc_or_ccr = ( stor_id == BLIS_RRC ) ||
	                           ( stor_id == BLIS_CCR );

	if ( !is_rrc_or_ccr )
	{
		return ( n < 150 ) ? TRUE : FALSE;
	}

	if ( n < 140 )                     return TRUE;
	if ( ( n < 200 ) && ( k < 100 ) )  return TRUE;
	if ( ( n <= 450 ) && ( k < 50 ) )  return TRUE;

	return FALSE;
}

View File

@@ -5,7 +5,7 @@
libraries.
Copyright (C) 2014, The University of Texas at Austin
Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved.
Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -457,6 +457,13 @@ void GENBARNAME(cntx_init)
cntx
);
// -- Set level-3 threshold functions -------------------------------------
vfuncs = bli_cntx_l3_thresh_funcs_buf( cntx );
// Initialize all of the function pointers to default function
for ( i = 0; i < BLIS_NUM_LEVEL3_OPS; ++i )
vfuncs[ i ] = bli_cntx_l3_sup_thresh_is_met;
// -- Set level-3 small/unpacked handlers ----------------------------------