Added Parameter Checks and DTL Trace for Extension APIs

1. Added input parameter checking for the extension APIs
   1. gemm_pack_get_size API
   2. gemm_pack API

2. Additionally added early returns for these APIs when
   m or n dimensions are 0.

3. The input parameter check routines for all three
   BLAS extension APIs (gemm_pack_get_size, gemm_pack and
   gemm_compute) are defined in:
   frame/compat/check/bla_gemm_pack_compute_check.h

4. Added AOCL DTL TRACE for all the functions of
   1. gemm_pack_get_size
   2. gemm_pack
   3. gemm_compute

AMD-Internal: [CPUPL-3560]
Change-Id: I4351b8494d888eae7e7431a7e1e23e442ffc8631
This commit is contained in:
Eashan Dash
2023-11-07 15:19:29 +05:30
parent 75a4d2f72f
commit e4e4fe55fb
13 changed files with 269 additions and 35 deletions

View File

@@ -45,6 +45,8 @@ void bli_pack_full_init
rntm_t* rntm
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2);
// Initializing the cntx if one isn't already passed.
if ( cntx == NULL ) {
cntx = bli_gks_query_cntx();
@@ -77,6 +79,8 @@ void bli_pack_full_init
rntm
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2);
}
// Full pack function for A matrix
@@ -98,6 +102,8 @@ void PASTEMAC(ch,tfuncname) \
thrinfo_t* thread \
) \
{\
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); \
\
const num_t dt = PASTEMAC(ch,type); \
\
/* Query the context for various blocksizes. */ \
@@ -191,6 +197,8 @@ void PASTEMAC(ch,tfuncname) \
\
} \
} \
\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_5); \
\
} \
@@ -217,6 +225,8 @@ void PASTEMAC(ch,tfuncname) \
thrinfo_t* thread \
) \
{ \
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); \
\
const num_t dt = PASTEMAC(ch,type); \
\
/* Query the context for various blocksizes. */ \
@@ -354,6 +364,8 @@ void PASTEMAC(ch,tfuncname) \
adjust_B_panel_reordered_jc( &jj, jc_cur_loop ); \
\
} \
\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_5); \
\
} \
@@ -374,6 +386,8 @@ void PASTEMAC(ch,tfuncname) \
thrinfo_t* thread \
) \
{ \
\
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); \
\
const num_t dt = bli_obj_dt( src_obj ); \
\
@@ -429,6 +443,8 @@ void PASTEMAC(ch,tfuncname) \
thread \
); \
} \
\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); \
\
} \

View File

@@ -45,6 +45,8 @@ void bli_gemm_compute_init
rntm_t* rntm
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2);
if ( bli_error_checking_is_enabled() )
{
// @todo: Add call to error checking function here
@@ -97,9 +99,11 @@ void bli_gemm_compute_init
cntx,
rntm
);
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2);
}
err_t bli_gemm_compute
void bli_gemm_compute
(
obj_t* a,
obj_t* b,
@@ -110,6 +114,8 @@ err_t bli_gemm_compute
thrinfo_t* thread
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4);
const num_t dt = bli_obj_dt( c );
const dim_t m = bli_obj_length( c );
const dim_t n = bli_obj_width( c );
@@ -242,7 +248,8 @@ err_t bli_gemm_compute
);
}
return BLIS_SUCCESS;
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4);
}
#undef GENTFUNC
@@ -267,6 +274,8 @@ void PASTEMAC( ch, varname ) \
thrinfo_t* restrict thread \
) \
{ \
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); \
\
const num_t dt = PASTEMAC( ch, type ); \
\
/* If m or n is zero, return immediately. */ \
@@ -644,6 +653,9 @@ void PASTEMAC( ch, varname ) \
&mem_b, \
thread_pb \
); \
\
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_5); \
\
}
INSERT_GENTFUNC_BASIC0_SD( gemm_compute )

View File

@@ -42,7 +42,7 @@ void bli_gemm_compute_init
rntm_t* rntm
);
err_t bli_gemm_compute
void bli_gemm_compute
(
obj_t* a,
obj_t* b,

View File

@@ -54,6 +54,8 @@ void sgemm_compute_blis_impl
float* c, const f77_int* rs_c, const f77_int* cs_c
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
@@ -83,11 +85,12 @@ void sgemm_compute_blis_impl
rs_c, cs_c
);
/* Quick return if possible. */
/* Quick return. */
if ( *m == 0 || *n == 0 )
{
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
@@ -131,6 +134,9 @@ void sgemm_compute_blis_impl
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
@@ -176,6 +182,8 @@ void dgemm_compute_blis_impl
double* c, const f77_int* rs_c, const f77_int* cs_c
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
trans_t blis_transa;
trans_t blis_transb;
dim_t m0, n0, k0;
@@ -205,11 +213,12 @@ void dgemm_compute_blis_impl
rs_c, cs_c
);
/* Quick return if possible. */
/* Quick return. */
if ( *m == 0 || *n == 0 )
{
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
@@ -253,6 +262,10 @@ void dgemm_compute_blis_impl
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
#ifdef BLIS_ENABLE_BLAS

View File

@@ -53,10 +53,36 @@ void sgemm_pack_blis_impl
float* dest
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
dim_t m;
dim_t n;
dim_t k;
bli_init_auto(); // initialize blis
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm_pack)
(
MKSTR(s),
MKSTR(gemm),
identifier,
trans,
mm,
nn,
kk,
pld
);
/* Quick return. */
if ( *mm == 0 || *nn == 0 )
{
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
dim_t m0 = 0;
dim_t n0 = 0;
@@ -88,11 +114,6 @@ void sgemm_pack_blis_impl
{
bli_set_dims_with_trans( blis_trans, k, n, &m0, &n0 );
}
else
{
bli_print_msg( " Invalid IDENTIFIER setting sgemm_pack_() .", __FILE__, __LINE__ );
return;
}
bli_obj_init_finish_1x1( dt, (float*)alpha, &alpha_obj );
@@ -102,6 +123,13 @@ void sgemm_pack_blis_impl
bli_obj_set_conjtrans( blis_trans, &src_obj );
bli_pack_full_init(identifier, &alpha_obj, &src_obj, &dest_obj, NULL, NULL);
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
void sgemm_pack_
@@ -131,10 +159,36 @@ void dgemm_pack_blis_impl
double* dest
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
dim_t m;
dim_t n;
dim_t k;
bli_init_auto(); // initialize blis
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm_pack)
(
MKSTR(d),
MKSTR(gemm),
identifier,
trans,
mm,
nn,
kk,
pld
);
/* Quick return. */
if ( *mm == 0 || *nn == 0 )
{
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
dim_t m0 = 0;
dim_t n0 = 0;
@@ -165,11 +219,6 @@ void dgemm_pack_blis_impl
{
bli_set_dims_with_trans( blis_trans, k, n, &m0, &n0 );
}
else
{
bli_print_msg( " Invalid IDENTIFIER setting dgemm_pack_() .", __FILE__, __LINE__ );
return;
}
bli_obj_init_finish_1x1( dt, (double*)alpha, &alpha_obj );
@@ -179,6 +228,13 @@ void dgemm_pack_blis_impl
bli_obj_set_conjtrans( blis_trans, &src_obj );
bli_pack_full_init(identifier, &alpha_obj, &src_obj, &dest_obj, NULL, NULL);
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return;
}
void dgemm_pack_

View File

@@ -55,6 +55,26 @@ f77_int dgemm_pack_get_size_blis_impl
bli_init_auto(); // initialize blis
cntx_t* cntx = bli_gks_query_cntx(); // Get processor specific context.
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm_get_size)
(
MKSTR(d),
MKSTR(gemm),
identifier,
pm,
pn,
pk
);
/* Quick return. */
if ( *pm == 0 || *pn == 0 )
{
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return 0;
}
num_t dt = BLIS_DOUBLE; // Double precision
f77_int tbytes = 0; // total number of bytes needed for packing.
f77_int m = *pm;
@@ -126,14 +146,12 @@ f77_int dgemm_pack_get_size_blis_impl
tbytes = ps_max * sizeof( double );
}
else
{
bli_print_msg( " Invalid IDENTIFIER setting dgemm_pack_get_size_() .", __FILE__, __LINE__ );
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return tbytes;
}
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return tbytes;
}
@@ -158,9 +176,31 @@ f77_int sgemm_pack_get_size_blis_impl
const f77_int* pk
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1);
bli_init_auto(); // initialize blis
cntx_t* cntx = bli_gks_query_cntx(); // Get processor specific context.
/* Perform BLAS parameter checking. */
PASTEBLACHK(gemm_get_size)
(
MKSTR(s),
MKSTR(gemm),
identifier,
pm,
pn,
pk
);
/* Quick return. */
if ( *pm == 0 || *pn == 0 )
{
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return 0;
}
num_t dt = BLIS_FLOAT; // Single precision
f77_int tbytes = 0; // total number of bytes needed for packing.
f77_int m = *pm;
@@ -232,11 +272,11 @@ f77_int sgemm_pack_get_size_blis_impl
tbytes = ps_max * sizeof( float );
}
else
{
bli_print_msg( " Invalid IDENTIFIER setting sgemm_pack_get_size_() .", __FILE__, __LINE__ );
return tbytes;
}
/* Finalize BLIS. */
bli_finalize_auto();
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1);
return tbytes;
}

View File

@@ -195,7 +195,7 @@
#include "bla_trmm_check.h"
#include "bla_trsm_check.h"
#include "bla_gemmt_check.h"
#include "bla_gemm_compute_check.h"
#include "bla_gemm_pack_compute_check.h"
// -- Batch Extension prototypes --
#include "bla_gemm_batch.h"

View File

@@ -32,6 +32,89 @@
*/
/*
 * Input parameter check for the gemm_pack_get_size extension API.
 *
 * The macro expands in place inside the calling function.
 *
 *   dt_str, op_str - strings joined by the "%s%-5s" format to build the
 *                    routine name reported to PASTE_XERBLA.
 *   identifier     - compared against "A" and "B" through PASTE_LSAME.
 *   m, n, k        - pointers to the problem dimensions.
 *
 * The first failing condition selects the value of info:
 *   1 - identifier is neither "A" nor "B"
 *   2 - *m < 0
 *   3 - *n < 0
 *   4 - *k < 0
 *
 * When info is non-zero the routine name is upper-cased, PASTE_XERBLA is
 * invoked, and the *calling* function returns 0.
 *
 * NOTE(review): this return is issued from inside the macro, so any
 * statements the caller places after the check (for example
 * bli_finalize_auto() or a DTL trace exit) are skipped on the error path;
 * confirm this is intended.
 */
#define bla_gemm_get_size_check( dt_str, op_str, identifier, m, n, k ) \
{ \
	f77_int info = 0; \
	f77_int A_identifier, B_identifier; \
\
	A_identifier = PASTE_LSAME( identifier, "A", (ftnlen)1, (ftnlen)1 ); \
	B_identifier = PASTE_LSAME( identifier, "B", (ftnlen)1, (ftnlen)1 ); \
\
	if ( !A_identifier && !B_identifier ) \
		info = 1; \
	else if ( *m < 0 ) \
		info = 2; \
	else if ( *n < 0 ) \
		info = 3; \
	else if ( *k < 0 ) \
		info = 4; \
\
	if ( info != 0 ) \
	{ \
		char func_str[ BLIS_MAX_BLAS_FUNC_STR_LENGTH ]; \
\
		/* Bounded write: an over-long op_str is truncated instead of */ \
		/* overrunning func_str. */ \
		snprintf( func_str, sizeof( func_str ), "%s%-5s", dt_str, op_str ); \
\
		bli_string_mkupper( func_str ); \
\
		PASTE_XERBLA( func_str, &info, (ftnlen)6 ); \
\
		return 0; \
	} \
}
/*
 * Input parameter check for the gemm_pack extension API.
 *
 * The macro expands in place inside the calling function.
 *
 *   dt_str, op_str - strings joined by the "%s%-5s" format to build the
 *                    routine name reported to PASTE_XERBLA.
 *   identifier     - compared against "A" and "B" through PASTE_LSAME.
 *   trans          - compared against "N", "C" and "T" through PASTE_LSAME.
 *   m, n, k        - pointers to the problem dimensions.
 *   pld            - pointer to the leading dimension of the source matrix.
 *
 * nrow is the value the leading dimension is checked against:
 *   identifier "A": *m when trans is "N", otherwise *k
 *   identifier "B": *k when trans is "N", otherwise *n
 *
 * The first failing condition selects the value of info:
 *   1 - identifier is neither "A" nor "B"
 *   2 - trans is none of "N", "C", "T"
 *   3 - *m < 0
 *   4 - *n < 0
 *   5 - *k < 0
 *   6 - *pld < max( 1, nrow )
 *
 * When info is non-zero the routine name is upper-cased, PASTE_XERBLA is
 * invoked, and the *calling* function returns.
 *
 * NOTE(review): this return is issued from inside the macro, so any
 * statements the caller places after the check (for example
 * bli_finalize_auto() or a DTL trace exit) are skipped on the error path;
 * confirm this is intended.
 */
#define bla_gemm_pack_check( dt_str, op_str, identifier, trans, m, n, k, pld ) \
{ \
	f77_int info = 0; \
	f77_int A_identifier, B_identifier; \
	f77_int no_trans_param, conj_param, trans_param; \
	/* Initialized so nrow has a defined value even when the identifier */ \
	/* is invalid; in that case info == 1 is selected before it is read. */ \
	f77_int nrow = 0; \
\
	A_identifier = PASTE_LSAME( identifier, "A", (ftnlen)1, (ftnlen)1 ); \
	B_identifier = PASTE_LSAME( identifier, "B", (ftnlen)1, (ftnlen)1 ); \
\
	no_trans_param = PASTE_LSAME( trans, "N", (ftnlen)1, (ftnlen)1 ); \
	conj_param     = PASTE_LSAME( trans, "C", (ftnlen)1, (ftnlen)1 ); \
	trans_param    = PASTE_LSAME( trans, "T", (ftnlen)1, (ftnlen)1 ); \
\
	if ( A_identifier ) \
	{ \
		if ( no_trans_param ) { nrow = *m; } \
		else                  { nrow = *k; } \
	} \
	else if ( B_identifier ) \
	{ \
		if ( no_trans_param ) { nrow = *k; } \
		else                  { nrow = *n; } \
	} \
\
	if ( !A_identifier && !B_identifier ) \
		info = 1; \
	else if ( !no_trans_param && !conj_param && !trans_param ) \
		info = 2; \
	else if ( *m < 0 ) \
		info = 3; \
	else if ( *n < 0 ) \
		info = 4; \
	else if ( *k < 0 ) \
		info = 5; \
	else if ( *pld < bli_max( 1, nrow ) ) \
		info = 6; \
\
	if ( info != 0 ) \
	{ \
		char func_str[ BLIS_MAX_BLAS_FUNC_STR_LENGTH ]; \
\
		/* Bounded write: an over-long op_str is truncated instead of */ \
		/* overrunning func_str. */ \
		snprintf( func_str, sizeof( func_str ), "%s%-5s", dt_str, op_str ); \
\
		bli_string_mkupper( func_str ); \
\
		PASTE_XERBLA( func_str, &info, (ftnlen)6 ); \
\
		return; \
	} \
}
#define bla_gemm_compute_check( dt_str, op_str, transa, transb, m, n, k, lda, ldb, rs_c, cs_c ) \
{ \
f77_int info = 0; \
@@ -89,4 +172,4 @@
\
return; \
} \
}
}

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2023, Advanced Micro Devices, Inc.
Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -36,7 +36,7 @@
#define BLIS_L3_COMPUTE_DECOR_H
// Level-3 compute internal function type.
typedef err_t (*l3computeint_t)
typedef void (*l3computeint_t)
(
obj_t* a,
obj_t* b,
@@ -48,7 +48,7 @@ typedef err_t (*l3computeint_t)
);
// Level-3 compute thread decorator prototype.
err_t bli_l3_compute_thread_decorator
void bli_l3_compute_thread_decorator
(
l3computeint_t func,
opid_t family,

View File

@@ -38,7 +38,7 @@
void* bli_l3_compute_thread_entry( void* data_void ) { return NULL; }
err_t bli_l3_compute_thread_decorator
void bli_l3_compute_thread_decorator
(
l3computeint_t func,
opid_t family,
@@ -50,6 +50,8 @@ err_t bli_l3_compute_thread_decorator
rntm_t* rntm
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3);
// Query the total number of threads from the rntm_t object.
const dim_t n_threads = bli_rntm_num_threads( rntm );
@@ -123,7 +125,8 @@ err_t bli_l3_compute_thread_decorator
// mutual exclusion.
bli_sba_checkin_array( array );
return BLIS_SUCCESS;
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
#endif

View File

@@ -36,7 +36,7 @@
#if !defined (BLIS_ENABLE_MULTITHREADING) || defined (BLIS_ENABLE_PTHREADS)
err_t bli_l3_compute_thread_decorator
void bli_l3_compute_thread_decorator
(
l3computeint_t func,
opid_t family,
@@ -48,6 +48,8 @@ err_t bli_l3_compute_thread_decorator
rntm_t* rntm
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3);
const dim_t n_threads = 1;
array_t* restrict array = bli_sba_checkout_array( n_threads );
bli_sba_rntm_set_pool( 0, array, rntm );
@@ -81,7 +83,8 @@ err_t bli_l3_compute_thread_decorator
bli_sba_checkin_array( array );
return BLIS_SUCCESS;
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
#endif

View File

@@ -49,6 +49,8 @@ void bli_pack_full_thread_decorator
rntm_t* rntm
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3);
dim_t n_threads = bli_rntm_num_threads( rntm );
/* Ensure n_threads is always greater than or equal to 1 */
@@ -76,6 +78,8 @@ void bli_pack_full_thread_decorator
&thread
);
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
#endif

View File

@@ -49,6 +49,8 @@ void bli_pack_full_thread_decorator
rntm_t* rntm
)
{
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3);
thrinfo_t thread = BLIS_GEMM_SINGLE_THREADED;
{
@@ -66,6 +68,8 @@ void bli_pack_full_thread_decorator
);
}
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3);
}
#endif