mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Added a new field in cntx to store l3 threshold function pointers
Details: - Adding threshold function pointers to cntx gives flexibility to choose different threshold functions for different configurations. - In case of fat binary where configuration is decided at run-time, adding threshold functions under a macro enables these functions for all the configs under a family. This can be avoided by adding function pointers to cntx which can be queried from cntx during run-time based on the config chosen. Change-Id: Iaf7e69e45ae5bb60e4d0f75c7542a91e1609773f
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -66,6 +66,17 @@ void bli_cntx_init_zen( cntx_t* cntx )
|
||||
cntx
|
||||
);
|
||||
|
||||
// Update the context with architecture specific threshold functions
|
||||
bli_cntx_set_l3_thresh_funcs
|
||||
(
|
||||
2,
|
||||
// GEMMT
|
||||
BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen,
|
||||
// SYRK
|
||||
BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen,
|
||||
cntx
|
||||
);
|
||||
|
||||
// Update the context with optimized level-1f kernels.
|
||||
bli_cntx_set_l1f_kers
|
||||
(
|
||||
|
||||
@@ -63,6 +63,17 @@ void bli_cntx_init_zen2( cntx_t* cntx )
|
||||
cntx
|
||||
);
|
||||
|
||||
// Update the context with architecture specific threshold functions
|
||||
bli_cntx_set_l3_thresh_funcs
|
||||
(
|
||||
2,
|
||||
//gemmt
|
||||
BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen,
|
||||
//SYRK
|
||||
BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen,
|
||||
cntx
|
||||
);
|
||||
|
||||
// Update the context with optimized packm kernels.
|
||||
bli_cntx_set_packm_kers
|
||||
(
|
||||
|
||||
@@ -63,6 +63,17 @@ void bli_cntx_init_zen3( cntx_t* cntx )
|
||||
cntx
|
||||
);
|
||||
|
||||
// Update the context with architecture specific threshold functions
|
||||
bli_cntx_set_l3_thresh_funcs
|
||||
(
|
||||
2,
|
||||
// GEMMT
|
||||
BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen,
|
||||
// SYRK
|
||||
BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen,
|
||||
cntx
|
||||
);
|
||||
|
||||
// packm kernels
|
||||
bli_cntx_set_packm_kers
|
||||
(
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -100,3 +101,14 @@ INSERT_GENTDEF( trsm )
|
||||
|
||||
#endif
|
||||
|
||||
// These function pointers are used to hold addresses of functions
|
||||
// that decide which algorithm to choose between SUP and native
|
||||
// implementations based on input dimensions. These are stored
|
||||
// in cntx of respective configurations.
|
||||
typedef bool (*thresh_func_ft)
|
||||
(
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* c,
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
@@ -71,8 +71,6 @@ err_t bli_gemmsup
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
|
||||
const dim_t m = bli_obj_length( c );
|
||||
const dim_t n = bli_obj_width( c );
|
||||
trans_t transa = bli_obj_conjtrans_status( a );
|
||||
trans_t transb = bli_obj_conjtrans_status( b );
|
||||
|
||||
@@ -104,30 +102,14 @@ err_t bli_gemmsup
|
||||
// that function assumes the context pointer is valid.
|
||||
if ( cntx == NULL ) cntx = bli_gks_query_cntx();
|
||||
|
||||
// Return early if a microkernel preference-induced transposition would
|
||||
// have been performed and shifted the dimensions outside of the space
|
||||
// of sup-handled problems.
|
||||
if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( c, BLIS_GEMM_UKR, cntx ) )
|
||||
{
|
||||
const num_t dt = bli_obj_dt( c );
|
||||
const dim_t k = bli_obj_width_after_trans( a );
|
||||
thresh_func_ft func_fp;
|
||||
|
||||
// Pass in m and n reversed, which simulates a transposition of the
|
||||
// entire operation pursuant to the microkernel storage preference.
|
||||
if ( !bli_cntx_l3_sup_thresh_is_met( dt, n, m, k, cntx ) ) {
|
||||
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Trasposition results in sizes beyond SUP thresholds.");
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
}
|
||||
else // ukr_prefers_storage_of( c, ... )
|
||||
{
|
||||
const num_t dt = bli_obj_dt( c );
|
||||
const dim_t k = bli_obj_width_after_trans( a );
|
||||
func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx);
|
||||
|
||||
if ( !bli_cntx_l3_sup_thresh_is_met( dt, m, n, k, cntx ) ) {
|
||||
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds.");
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
// Return early if the sizes are beyond SUP thresholds
|
||||
if ( !func_fp( a, b, c, cntx ) ) {
|
||||
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds.");
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
|
||||
// Initialize a local runtime with global settings if necessary. Note
|
||||
@@ -215,7 +197,6 @@ err_t bli_gemmtsup
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
|
||||
const dim_t n = bli_obj_width( c );
|
||||
trans_t transa = bli_obj_conjtrans_status( a );
|
||||
trans_t transb = bli_obj_conjtrans_status( b );
|
||||
|
||||
@@ -247,10 +228,11 @@ err_t bli_gemmtsup
|
||||
// that function assumes the context pointer is valid.
|
||||
if ( cntx == NULL ) cntx = bli_gks_query_cntx();
|
||||
|
||||
num_t dt = bli_obj_dt(c);
|
||||
dim_t k = bli_obj_width_after_trans( a );
|
||||
thresh_func_ft func_fp;
|
||||
|
||||
if ( !bli_cntx_gemmtsup_thresh_is_met( dt, n, k, cntx ) )
|
||||
func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMMT, cntx);
|
||||
|
||||
if ( !func_fp( a, b, c, cntx ) )
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes beyond SUP thresholds.");
|
||||
return BLIS_FAILURE;
|
||||
@@ -343,7 +325,6 @@ err_t bli_syrksup
|
||||
return BLIS_FAILURE;
|
||||
}
|
||||
|
||||
const dim_t n = bli_obj_width( c );
|
||||
trans_t transa = bli_obj_conjtrans_status( a );
|
||||
|
||||
//Don't use sup for currently unsupported storage types in cgemmsup
|
||||
@@ -372,10 +353,8 @@ err_t bli_syrksup
|
||||
// that function assumes the context pointer is valid.
|
||||
if ( cntx == NULL ) cntx = bli_gks_query_cntx();
|
||||
|
||||
num_t dt = bli_obj_dt( c );
|
||||
dim_t k = bli_obj_width_after_trans( a );
|
||||
|
||||
if( !bli_cntx_syrksup_thresh_is_met( dt, n, k, stor_id, cntx))
|
||||
thresh_func_ft func_fp = bli_cntx_get_l3_thresh_func(BLIS_SYRK, cntx);
|
||||
if( !func_fp( a, &at_local, c, cntx))
|
||||
{
|
||||
AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - sizes beyond SUP thresholds.");
|
||||
return BLIS_FAILURE;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2020, Advanced Micro Devices, Inc.
|
||||
Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -1605,6 +1605,105 @@ void bli_cntx_set_l1v_kers( dim_t n_kers, ... )
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
void bli_cntx_set_l3_thresh_funcs( dim_t n_funcs, ... )
|
||||
{
|
||||
// This function can be called from the bli_cntx_init_*() function for
|
||||
// a particular architecture if the kernel developer wishes to use
|
||||
// non-default level-3 threshold functions. It should be called after
|
||||
// bli_cntx_init_defaults() so that the context begins with default
|
||||
// functionss across all operations.
|
||||
|
||||
/* Example prototypes:
|
||||
|
||||
void bli_cntx_set_l3_thresh_funcs
|
||||
(
|
||||
dim_t n_funcs,
|
||||
opid_t op1_id, void_fp ker0_fp,
|
||||
opid_t op2_id, void_fp ker1_fp,
|
||||
opid_t op2_id, void_fp ker2_fp,
|
||||
...
|
||||
cntx_t* cntx
|
||||
);
|
||||
*/
|
||||
|
||||
va_list args;
|
||||
dim_t i;
|
||||
|
||||
// Allocate some temporary local arrays.
|
||||
|
||||
#ifdef BLIS_ENABLE_MEM_TRACING
|
||||
printf( "bli_cntx_set_l3_thresh_funcs(): " );
|
||||
#endif
|
||||
l1vkr_t* func_ids = bli_malloc_intl( n_funcs * sizeof( opid_t ) );
|
||||
|
||||
#ifdef BLIS_ENABLE_MEM_TRACING
|
||||
printf( "bli_cntx_set_l3_thresh_funcs(): " );
|
||||
#endif
|
||||
void_fp* func_fps = bli_malloc_intl( n_funcs * sizeof( void_fp ) );
|
||||
|
||||
// -- Begin variable argument section --
|
||||
|
||||
// Initialize variable argument environment.
|
||||
va_start( args, n_funcs );
|
||||
|
||||
// Process n_funcs tuples.
|
||||
for ( i = 0; i < n_funcs; ++i )
|
||||
{
|
||||
// Here, we query the variable argument list for:
|
||||
// - the opid_t of the function we're about to process,
|
||||
// - the function pointer
|
||||
// that we need to store to the context.
|
||||
const opid_t op_id = ( opid_t )va_arg( args, opid_t );
|
||||
void_fp func_fp = ( void_fp )va_arg( args, void_fp );
|
||||
|
||||
// Store the values in our temporary arrays.
|
||||
func_ids[ i ] = op_id;
|
||||
func_fps[ i ] = func_fp;
|
||||
}
|
||||
|
||||
// The last argument should be the context pointer.
|
||||
cntx_t* cntx = ( cntx_t* )va_arg( args, cntx_t* );
|
||||
|
||||
// Shutdown variable argument environment and clean up stack.
|
||||
va_end( args );
|
||||
|
||||
// -- End variable argument section --
|
||||
|
||||
// Query the context for the address of:
|
||||
// - the level-3 threshold func array
|
||||
void_fp* cntx_l3_thresh_funcs = bli_cntx_l3_thresh_funcs_buf( cntx );
|
||||
|
||||
// Now that we have the context address, we want to copy the values
|
||||
// from the temporary buffers into the corresponding buffers in the
|
||||
// context.
|
||||
|
||||
// Process each blocksize id tuple provided.
|
||||
for ( i = 0; i < n_funcs; ++i )
|
||||
{
|
||||
// Read the current func id, and function pointer.
|
||||
const opid_t func_id = func_ids[ i ];
|
||||
void_fp func_fp = func_fps[ i ];
|
||||
|
||||
// Store function pointer in cntx
|
||||
cntx_l3_thresh_funcs[ func_id ] = func_fp;
|
||||
|
||||
}
|
||||
|
||||
// Free the temporary local arrays.
|
||||
|
||||
#ifdef BLIS_ENABLE_MEM_TRACING
|
||||
printf( "bli_cntx_set_l3_thresh_funcs(): " );
|
||||
#endif
|
||||
bli_free_intl( func_ids );
|
||||
|
||||
#ifdef BLIS_ENABLE_MEM_TRACING
|
||||
printf( "bli_cntx_set_l3_thresh_funcs(): " );
|
||||
#endif
|
||||
bli_free_intl( func_fps );
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
void bli_cntx_set_packm_kers( dim_t n_kers, ... )
|
||||
{
|
||||
// This function can be called from the bli_cntx_init_*() function for
|
||||
|
||||
@@ -128,6 +128,10 @@ BLIS_INLINE func_t* bli_cntx_l1v_kers_buf( cntx_t* cntx )
|
||||
{
|
||||
return cntx->l1v_kers;
|
||||
}
|
||||
BLIS_INLINE void** bli_cntx_l3_thresh_funcs_buf(cntx_t* cntx )
|
||||
{
|
||||
return cntx->l3_thresh_funcs;
|
||||
}
|
||||
BLIS_INLINE func_t* bli_cntx_packm_kers_buf( cntx_t* cntx )
|
||||
{
|
||||
return cntx->packm_kers;
|
||||
@@ -307,65 +311,6 @@ BLIS_INLINE dim_t bli_cntx_get_l3_sup_thresh_dt( num_t dt, threshid_t thresh_id,
|
||||
return thresh_dt;
|
||||
}
|
||||
|
||||
BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( num_t dt, dim_t m, dim_t n, dim_t k, cntx_t* cntx )
|
||||
{
|
||||
if ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) return TRUE;
|
||||
if ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) return TRUE;
|
||||
if ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) ) return TRUE;
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
// -- gemmt specific function
|
||||
BLIS_INLINE bool bli_cntx_gemmtsup_thresh_is_met( num_t dt, dim_t n, dim_t k, cntx_t* cntx )
|
||||
{
|
||||
#ifdef BLIS_CONFIG_EPYC
|
||||
if( bli_is_double( dt ))
|
||||
{
|
||||
if ( n < 300 ) return TRUE;
|
||||
if ( (k / n ) > 50 ) return TRUE;
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
else if ( bli_is_dcomplex( dt ) )
|
||||
{
|
||||
if ( n < 100 ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
else
|
||||
return bli_cntx_l3_sup_thresh_is_met( dt, n, n, k, cntx );
|
||||
#else
|
||||
return bli_cntx_l3_sup_thresh_is_met( dt, n, n, k, cntx );
|
||||
#endif
|
||||
}
|
||||
|
||||
// -- syrk specific function
|
||||
BLIS_INLINE bool bli_cntx_syrksup_thresh_is_met( num_t dt, dim_t n, dim_t k, stor3_t stor_id, cntx_t* cntx )
|
||||
{
|
||||
#ifdef BLIS_CONFIG_EPYC
|
||||
if( bli_is_double( dt ) )
|
||||
{
|
||||
if( ( stor_id == BLIS_RRC ) || ( stor_id == BLIS_CCR ) )
|
||||
{
|
||||
if( n < 140) return TRUE;
|
||||
else if( ( n < 200 ) && ( k < 100 ) ) return TRUE;
|
||||
else if( ( n <= 450 ) && ( k < 50 ) ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
else
|
||||
{
|
||||
if( n < 150 ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
}
|
||||
else
|
||||
return bli_cntx_l3_sup_thresh_is_met( dt, n, n, k, cntx );
|
||||
#else
|
||||
//copied gemm thresholds temporarily. These needs to be derived for syrk.
|
||||
return bli_cntx_l3_sup_thresh_is_met( dt, n, n, k, cntx );
|
||||
#endif
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
BLIS_INLINE void* bli_cntx_get_l3_sup_handler( opid_t op, cntx_t* cntx )
|
||||
@@ -520,6 +465,14 @@ BLIS_INLINE void_fp bli_cntx_get_packm_ker_dt( num_t dt, l1mkr_t ker_id, cntx_t*
|
||||
return fp;
|
||||
}
|
||||
|
||||
BLIS_INLINE void* bli_cntx_get_l3_thresh_func( opid_t func_id, cntx_t* cntx )
|
||||
{
|
||||
void** funcs = bli_cntx_l3_thresh_funcs_buf( cntx );
|
||||
void* func = funcs[ func_id ];
|
||||
|
||||
return func;
|
||||
}
|
||||
|
||||
BLIS_INLINE func_t* bli_cntx_get_unpackm_kers( l1mkr_t ker_id, cntx_t* cntx )
|
||||
{
|
||||
func_t* func = NULL;
|
||||
@@ -599,7 +552,6 @@ BLIS_INLINE bool bli_cntx_l3_nat_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_rows_dt( num_t dt, l3ukr_t ukr_id, cntx_t* cntx )
|
||||
{
|
||||
// For induced methods, return the ukernel storage preferences of the
|
||||
@@ -650,6 +602,34 @@ BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx )
|
||||
{
|
||||
num_t dt = bli_obj_dt( c );
|
||||
dim_t k = bli_obj_width_after_trans( a );
|
||||
|
||||
dim_t m, n;
|
||||
|
||||
if(bli_cntx_l3_vir_ukr_dislikes_storage_of(c, BLIS_GEMM_UKR, cntx ) )
|
||||
{
|
||||
m = bli_obj_width(c);
|
||||
n = bli_obj_length(c);
|
||||
}
|
||||
else
|
||||
{
|
||||
m = bli_obj_length( c );
|
||||
n = bli_obj_width( c );
|
||||
|
||||
}
|
||||
|
||||
if ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) return TRUE;
|
||||
if ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) return TRUE;
|
||||
if ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) ) return TRUE;
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
BLIS_INLINE bool bli_cntx_l3_sup_ker_prefers_rows_dt( num_t dt, stor3_t stor_id, cntx_t* cntx )
|
||||
{
|
||||
@@ -814,6 +794,7 @@ BLIS_EXPORT_BLIS void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... );
|
||||
BLIS_EXPORT_BLIS void bli_cntx_set_l1f_kers( dim_t n_kers, ... );
|
||||
BLIS_EXPORT_BLIS void bli_cntx_set_l1v_kers( dim_t n_kers, ... );
|
||||
BLIS_EXPORT_BLIS void bli_cntx_set_packm_kers( dim_t n_kers, ... );
|
||||
BLIS_EXPORT_BLIS void bli_cntx_set_l3_thresh_funcs( dim_t n_funcs, ... );
|
||||
|
||||
BLIS_EXPORT_BLIS void bli_cntx_print( cntx_t* cntx );
|
||||
|
||||
|
||||
@@ -1432,6 +1432,7 @@ typedef struct cntx_s
|
||||
func_t l3_vir_ukrs[ BLIS_NUM_LEVEL3_UKRS ];
|
||||
func_t l3_nat_ukrs[ BLIS_NUM_LEVEL3_UKRS ];
|
||||
mbool_t l3_nat_ukrs_prefs[ BLIS_NUM_LEVEL3_UKRS ];
|
||||
void* l3_thresh_funcs[ BLIS_NUM_LEVEL3_OPS ];
|
||||
|
||||
blksz_t l3_sup_thresh[ BLIS_NUM_THRESH ];
|
||||
void* l3_sup_handlers[ BLIS_NUM_LEVEL3_OPS ];
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.##
|
||||
|
||||
|
||||
set(SUBDIRECTORIES "1" "1f" "1m" "2" "3")
|
||||
set(SUBDIRECTORIES "1" "1f" "1m" "2" "3" "util")
|
||||
|
||||
#Add all subdirectories
|
||||
foreach(VAR ${SUBDIRECTORIES})
|
||||
|
||||
@@ -273,3 +273,19 @@ void bli_dgemm_ref_k1_nn
|
||||
cntl_t* cntl
|
||||
);
|
||||
|
||||
// threshold functions
|
||||
bool bli_cntx_gemmtsup_thresh_is_met_zen
|
||||
(
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* c,
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
bool bli_cntx_syrksup_thresh_is_met_zen
|
||||
(
|
||||
obj_t* a,
|
||||
obj_t* b,
|
||||
obj_t* c,
|
||||
cntx_t* cntx
|
||||
);
|
||||
|
||||
6
kernels/zen/util/CMakeLists.txt
Normal file
6
kernels/zen/util/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.##
|
||||
|
||||
target_sources("${PROJECT_NAME}"
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bli_thresh_funcs_zen.c
|
||||
)
|
||||
88
kernels/zen/util/bli_thresh_funcs_zen.c
Normal file
88
kernels/zen/util/bli_thresh_funcs_zen.c
Normal file
@@ -0,0 +1,88 @@
|
||||
/*
|
||||
|
||||
BLIS
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
- Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
- Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
- Neither the name(s) of the copyright holder(s) nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*/
|
||||
|
||||
#include "blis.h"
|
||||
|
||||
// -- gemmt specific function
|
||||
bool bli_cntx_gemmtsup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx )
|
||||
{
|
||||
num_t dt = bli_obj_dt( c );
|
||||
|
||||
dim_t n = bli_obj_length( c );
|
||||
dim_t k = bli_obj_width_after_trans( a );
|
||||
|
||||
if( bli_is_double( dt ))
|
||||
{
|
||||
if ( n < 300 ) return TRUE;
|
||||
if ( (k / n ) > 50 ) return TRUE;
|
||||
|
||||
return FALSE;
|
||||
}
|
||||
else if ( bli_is_dcomplex( dt ) )
|
||||
{
|
||||
if ( n < 100 ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
else
|
||||
return bli_cntx_l3_sup_thresh_is_met( a, b, c, cntx );
|
||||
}
|
||||
|
||||
// -- syrk specific function
|
||||
bool bli_cntx_syrksup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx )
|
||||
{
|
||||
num_t dt = bli_obj_dt( c );
|
||||
|
||||
dim_t n = bli_obj_length( c );
|
||||
dim_t k = bli_obj_width_after_trans( a );
|
||||
|
||||
stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b );
|
||||
|
||||
if( bli_is_double( dt ) )
|
||||
{
|
||||
if( ( stor_id == BLIS_RRC ) || ( stor_id == BLIS_CCR ) )
|
||||
{
|
||||
if( n < 140) return TRUE;
|
||||
else if( ( n < 200 ) && ( k < 100 ) ) return TRUE;
|
||||
else if( ( n <= 450 ) && ( k < 50 ) ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
else
|
||||
{
|
||||
if( n < 150 ) return TRUE;
|
||||
else return FALSE;
|
||||
}
|
||||
}
|
||||
else
|
||||
return bli_cntx_l3_sup_thresh_is_met( a, b, c, cntx );
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2014, The University of Texas at Austin
|
||||
Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -457,6 +457,13 @@ void GENBARNAME(cntx_init)
|
||||
cntx
|
||||
);
|
||||
|
||||
// -- Set level-3 threshold functions -------------------------------------
|
||||
vfuncs = bli_cntx_l3_thresh_funcs_buf( cntx );
|
||||
|
||||
// Initialize all of the function pointers to default function
|
||||
for ( i = 0; i < BLIS_NUM_LEVEL3_OPS; ++i )
|
||||
vfuncs[ i ] = bli_cntx_l3_sup_thresh_is_met;
|
||||
|
||||
|
||||
// -- Set level-3 small/unpacked handlers ----------------------------------
|
||||
|
||||
|
||||
Reference in New Issue
Block a user