Files
blis/bench/bench_aocl_gemm/bench_batch_lpgemm.c
Meghana Vankadari 69ca5dbcd6 Fixed compilation errors for gcc versions < 11.2
Details:
- Disabled intrinsics code of f32obf16 pack function
  for gcc < 11.2 as the instructions used in kernels
  are not supported by the compiler versions.
- Added early-return check for WOQ APIs when compiling with
  gcc < 11.2
- Fixed code to check whether JIT kernels are generated inside
  batch_gemm API for bf16 datatype.

AMD Internal: [CPUPL-6327]

Change-Id: I0a017c67eb9d9d22a14e095e435dc397e265fb0a
2025-01-21 07:13:31 -05:00

1607 lines
59 KiB
C

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name(s) of the copyright holder(s) nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "bench_lpgemm_helpers.h"
/* Max length of a post-ops specifier string parsed from the input file. */
#define POST_OPS_STR_LEN 104
/* Instantiate helper routines (declared as macros in
 * bench_lpgemm_helpers.h) for each element type used by the batch
 * benchmarks: convert-to-float, matrix printing, and array fill. */
CONVERT_TO_FLOAT(uint8_t)
CONVERT_TO_FLOAT(int8_t)
CONVERT_TO_FLOAT(int16_t)
CONVERT_TO_FLOAT(float)
CONVERT_TO_FLOAT(int32_t)
PRINT_MATRIX(uint8_t)
PRINT_MATRIX(int8_t)
PRINT_MATRIX(int16_t)
PRINT_MATRIX(float)
PRINT_MATRIX(int32_t)
GEN_FILL_ARRAY_FUNC(int8_t)
GEN_FILL_ARRAY_FUNC(int16_t)
GEN_FILL_ARRAY_FUNC(float)
GEN_FILL_ARRAY_FUNC(int32_t)
/* Fill 'arr' (treated as uint8_t[size]) with the repeating pattern
 * 0,1,2,3,4. A negative size is treated as a no-op. */
void fill_array_uint8_t ( void* arr, dim_t size )
{
	if ( size < 0 )
	{
		return;
	}
	uint8_t* buf = ( uint8_t* ) arr;
	for ( dim_t idx = 0; idx < size; ++idx )
	{
		buf[idx] = ( uint8_t )( idx % 5 );
	}
}
/* Fill 'arr' with a repeating byte pattern for int4 (s4) testing.
 * 'size' is the logical element count in int4 units; two int4 values are
 * packed per byte, so only ceil(size / 2) bytes are written. A negative
 * size is treated as a no-op. */
void fill_array_int4_c_t( void* arr, dim_t size )
{
	if ( size < 0 ) return;
	/* Byte patterns, each holding two packed int4 values. Values above
	 * 0x7F are cast explicitly: a bare 0x89 is out of range for int8_t
	 * and triggers implementation-defined conversion warnings. */
	const int8_t int4_c_t_values[8] =
	{
		( int8_t )0x01, ( int8_t )0x23, ( int8_t )0x45, ( int8_t )0x67,
		( int8_t )0x89, ( int8_t )0xAB, ( int8_t )0xCD, ( int8_t )0xEF
	};
	/* Fill in pairs for int4_t since 4 bits/half byte access is not
	 * straight forward. */
	dim_t int4_c_t_size = ( size + 1 ) / 2;
	int8_t* temp_arr = ( int8_t* )arr;
	for ( dim_t i = 0; i < int4_c_t_size; ++i )
	{
		temp_arr[i] = int4_c_t_values[ i % 8 ];
	}
}
/* Generates mat_mul_<BLAS_SFX>(): a thin wrapper that forwards one batch
 * of GEMM problems to the corresponding aocl_batch_gemm_<BLAS_SFX>()
 * LPGEMM API. All dimension, leading-dimension, scalar and matrix-pointer
 * arguments are arrays of length 'bs' (one entry per GEMM in the batch). */
#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \
void mat_mul_ ## BLAS_SFX \
( \
char* stor_order, \
char* transa, \
char* transb, \
char* op_a, \
char* op_b, \
dim_t bs, \
dim_t* m, \
dim_t* n, \
dim_t* k, \
ACCUM_type* alpha, \
A_type** a, \
dim_t* lda, \
B_type** b, \
dim_t* ldb, \
ACCUM_type* beta, \
C_type** c, \
dim_t* ldc, \
aocl_post_op** post_op\
) \
{ \
aocl_batch_gemm_ ## BLAS_SFX( stor_order, transa, transb, bs, m, n, k, \
alpha, \
(const A_type**)a, lda, op_a, \
(const B_type**)b, ldb, op_b, \
beta, \
c, ldc, post_op ); \
} \
/* Instantiate the batch-GEMM wrappers for every supported datatype combo. */
GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_BLIS_MAT_MUL_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
/* Return the GFLOPS rate for a single m x n x k GEMM that completed in
 * 'runtime' seconds (2*m*n*k flops). */
double get_gflops
(
dim_t m,
dim_t n,
dim_t k,
double runtime
)
{
	double flop_count = 2.0 * ( double )m * ( double )n * ( double )k;
	return flop_count / ( runtime * 1.0e9 );
}
/* Print one benchmark result line for a batch run: the summary line shows
 * the first GEMM's parameters plus the aggregate Gops figure. In accuracy
 * mode ('a'), the dimensions of the remaining GEMMs in the batch are also
 * dumped, one per line. */
void print_result
(
const char* msg,
int32_t n_repeats,
char* stor_order,
char* transa,
char* transb,
char* op_a,
char* op_b,
dim_t bs,
dim_t* m,
dim_t* n,
dim_t* k,
dim_t* lda,
dim_t* ldb,
dim_t* ldc,
double gflops
)
{
	printf( "%s bs: %ld, stor:%c, transa:%c, transb:%c, op_a:%c, op_b:%c, m: %ld, n: %ld, k: %ld, lda: %ld,"
	        " ldb: %ld, ldc: %ld, n_repeats: %d, Gops: %f \n",
	        msg, bs, stor_order[0], transa[0], transb[0], op_a[0], op_b[0],
	        m[0], n[0], k[0], lda[0], ldb[0], ldc[0], n_repeats, gflops );
	if ( bench_mode != 'a' )
	{
		return;
	}
	for ( dim_t idx = 1; idx < bs; ++idx )
	{
		printf( "stor:%c, transa:%c, transb:%c, op_a:%c, op_b:%c, m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld\n",
		        stor_order[idx], transa[idx], transb[idx], op_a[idx], op_b[idx],
		        m[idx], n[idx], k[idx], lda[idx], ldb[idx], ldc[idx] );
	}
}
/* Generates mat_mul_bench_driver_<BLAS_SFX>(): runs the batch GEMM
 * 'n_repeats' times, keeps the minimum wall-clock time across repeats,
 * then computes aggregate GFLOPS over all GEMMs in the batch
 * (sum of 2*m*n*k per GEMM) and prints the result. */
#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \
void mat_mul_bench_driver_ ## BLAS_SFX \
( \
char* stor_order, \
char* transa, \
char* transb, \
char* op_a, \
char* op_b, \
int32_t n_repeats, \
dim_t bs, \
dim_t* m, \
dim_t* n, \
dim_t* k, \
ACCUM_type* alpha, \
A_type** a, \
dim_t* lda, \
B_type** b, \
dim_t* ldb, \
ACCUM_type* beta, \
C_type** c, \
dim_t* ldc, \
aocl_post_op** post_op\
) \
{ \
double dtime; \
double dtime_save = DBL_MAX; \
\
for ( int32_t nr = 0; nr < n_repeats; ++nr ) \
{ \
dtime = bli_clock(); \
\
GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \
( \
stor_order, transa, transb, op_a, op_b, bs, m, n, k, \
alpha, \
a, lda, \
b, ldb, \
beta, \
c, ldc, \
post_op \
); \
\
dtime_save = bli_clock_min_diff( dtime_save, dtime ); \
\
} \
double ops = 0; \
for( dim_t i = 0; i < bs; i++ ) { ops += 2.0 * m[i] * n[i] * k[i];} \
double gflops = ( ops ) / ( dtime_save * 1.0e9 ); \
\
print_result( XSTR(BLAS_SFX), n_repeats, stor_order, transa, transb, op_a, op_b, bs, m, n, k, lda, ldb, ldc, gflops); \
} \
/* Instantiate the benchmark drivers for every supported datatype combo. */
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16)
GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32)
GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16)
/* Generates the reference downscale helper for integer output types:
 * scales the accumulator by the per-column scale factor, rounds to
 * nearest, adds the zero point, and saturates into
 * [DSCALE_CLIP_MIN, DSCALE_CLIP_MAX]. A scale_factor_len/zero_point_len
 * of 1 broadcasts the single value to every column 'j'. */
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
(\
ACCUM_type temp_accum,\
aocl_post_op* post_op, \
dim_t j \
)\
{ \
dim_t j_scale = j; \
if ( ( post_op->sum )->scale_factor_len == 1 ) \
{ \
j_scale = 0; \
} \
\
dim_t j_zp = j; \
if ( ( post_op->sum )->zero_point_len == 1 ) \
{ \
j_zp = 0; \
} \
\
ACCUM_type out_temp_accum = \
( ACCUM_type )min( \
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
*( ( C_type* )( post_op->sum )->zero_point + j_zp ), \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
return out_temp_accum; \
}\
/* Instantiate integer downscale helpers; float-output variants follow as
 * hand-written functions since they neither round nor saturate. */
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(uint8_t,int16_t,float,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,s8s8s16os8)
/* Reference downscale for bf16 output: scale the f32 accumulator and add
 * a bf16 zero point (converted to float first). A scale/zero-point length
 * of 1 broadcasts the single value to every column 'j'. */
static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
	dim_t j_scale = ( ( post_op->sum )->scale_factor_len == 1 ) ? 0 : j;
	dim_t j_zp = ( ( post_op->sum )->zero_point_len == 1 ) ? 0 : j;
	float scale = *( ( float* )( post_op->sum )->scale_factor + j_scale );
	float zp_float = 0.0;
	bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ),
	                   &zp_float );
	return ( temp_accum * scale ) + zp_float;
}
/* Reference downscale for f32 output: scale the accumulator and add the
 * f32 zero point. A scale/zero-point length of 1 broadcasts the single
 * value to every column 'j'. */
static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
)
{
	dim_t j_scale = ( ( post_op->sum )->scale_factor_len == 1 ) ? 0 : j;
	dim_t j_zp = ( ( post_op->sum )->zero_point_len == 1 ) ? 0 : j;
	float scale = *( ( float* )( post_op->sum )->scale_factor + j_scale );
	float zp = *( ( float* )( post_op->sum )->zero_point + j_zp );
	return ( temp_accum * scale ) + zp;
}
/* Generates the reference GEMM inner-product helper: accumulates
 * sum_p a(i,p)*b(p,j) into temp_accum, then applies
 * beta*C_ref(i,j) + alpha*accum. The int4_testing and pre_op parameters
 * are accepted for signature parity with the int4 variant but unused. */
#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
(\
A_type* a, \
B_type* b, \
C_type* c_ref, \
ACCUM_type temp_accum,\
ACCUM_type alpha, \
ACCUM_type beta, \
dim_t rs_a, \
dim_t rs_b, \
dim_t cs_a, \
dim_t cs_b, \
dim_t rs_c_ref, \
dim_t cs_c_ref, \
dim_t i, \
dim_t j, \
dim_t k, \
bool int4_testing, /* Workaround to enable int4 B matrix testing. */\
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
) \
{ \
( void )int4_testing; \
( void ) pre_op; \
for ( dim_t p = 0; p < k; ++p) \
{ \
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \
*( b + ( rs_b * p ) + ( cs_b * j ) ) ); \
} \
\
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \
+ ( alpha * temp_accum ); \
return temp_accum; \
} \
/* Instantiate the plain (non-int4) reference accumulate helpers. */
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,uint8_t,int16_t,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16)
GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16)
/* Variant of the reference accumulate helper that can also read B as a
 * nibble-packed int4 matrix (int4_testing == TRUE): even logical indices
 * occupy the low nibble of a byte, odd indices the high nibble, and the
 * extracted 4-bit value is sign-extended to int8_t before use. With
 * int4_testing == FALSE it behaves like GEN_MAT_MUL_ACC_CHK_ACCUM. */
#define GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \
(\
A_type* a, \
B_type* b, \
C_type* c_ref, \
ACCUM_type temp_accum,\
ACCUM_type alpha, \
ACCUM_type beta, \
dim_t rs_a, \
dim_t rs_b, \
dim_t cs_a, \
dim_t cs_b, \
dim_t rs_c_ref, \
dim_t cs_c_ref, \
dim_t i, \
dim_t j, \
dim_t k, \
bool int4_testing, /* Workaround to enable int4 B matrix testing. */\
aocl_pre_op* pre_op /* Workaround to enable B pre-ops. */ \
) \
{ \
( void ) pre_op; \
if ( int4_testing == FALSE ) \
{ \
for ( dim_t p = 0; p < k; ++p) \
{ \
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \
*( b + ( rs_b * p ) + ( cs_b * j ) ) ); \
} \
} \
else \
{ \
for ( dim_t p = 0; p < k; ++p) \
{ \
/* Get B matrix int4_t value and upscale it to int8_t. */ \
dim_t b_inc = ( rs_b * p ) + ( cs_b * j ); \
int8_t b_val = 0; \
/* Even index will have data at low 4 bits, and odd at hi 4 bits.
 * B matrix increments has to be halved to account for 4 bit
 * traversal. */ \
if ( ( b_inc % 2 ) != 0 ) \
{ \
b_val = ( ( *( b + ( b_inc / 2 ) ) ) >> 4 ) & 0x0F; \
} \
else \
{ \
b_val = ( *( b + ( b_inc / 2 ) ) ) & 0x0F; \
} \
/* Signed scale. */ \
if ( b_val & 0x08 ) \
{ \
b_val = b_val | 0xF0; \
} \
temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * b_val ); \
} \
} \
\
temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \
+ ( alpha * temp_accum ); \
return temp_accum; \
} \
/* Only the u8s8s32 pair supports packed-int4 B testing. */
GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_ACCUM_INT4(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32)
/* Reference inner product for bf16 x bf16 -> f32: converts each bf16
 * element to float, accumulates a(i,:) . b(:,j) on top of temp_accum,
 * then applies beta*C_ref(i,j) + alpha*accum. */
static inline float mat_mul_accuracy_check_accum_bf16bf16f32of32
(
bfloat16* a,
bfloat16* b,
float* c_ref,
float temp_accum,
float alpha,
float beta,
dim_t rs_a,
dim_t rs_b,
dim_t cs_a,
dim_t cs_b,
dim_t rs_c_ref,
dim_t cs_c_ref,
dim_t i,
dim_t j,
dim_t k,
bool int4_testing, /* Ignored for bf16 testing. */
aocl_pre_op* pre_op /* Unused; kept for signature parity. */
)
{
	( void )int4_testing;
	( void )pre_op;
	float dot = temp_accum;
	for ( dim_t p = 0; p < k; ++p )
	{
		float a_f = 0.0f;
		float b_f = 0.0f;
		bfloat16_to_float( a[( i * rs_a ) + ( p * cs_a )], &a_f );
		bfloat16_to_float( b[( p * rs_b ) + ( j * cs_b )], &b_f );
		dot += a_f * b_f;
	}
	float c_prev = c_ref[( i * rs_c_ref ) + ( j * cs_c_ref )];
	return ( beta * c_prev ) + ( alpha * dot );
}
/* Reference inner product for bf16 x bf16 with a bf16 reference C:
 * identical to the f32-output variant except that C_ref(i,j) is stored
 * as bf16 and must be widened to float before the beta blend. */
static inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16
(
bfloat16* a,
bfloat16* b,
bfloat16* c_ref,
float temp_accum,
float alpha,
float beta,
dim_t rs_a,
dim_t rs_b,
dim_t cs_a,
dim_t cs_b,
dim_t rs_c_ref,
dim_t cs_c_ref,
dim_t i,
dim_t j,
dim_t k,
bool int4_testing, /* Ignored for bf16 testing. */
aocl_pre_op* pre_op /* Unused; kept for signature parity. */
)
{
	( void )int4_testing;
	( void )pre_op;
	float dot = temp_accum;
	for ( dim_t p = 0; p < k; ++p )
	{
		float a_f = 0.0f;
		float b_f = 0.0f;
		bfloat16_to_float( a[( i * rs_a ) + ( p * cs_a )], &a_f );
		bfloat16_to_float( b[( p * rs_b ) + ( j * cs_b )], &b_f );
		dot += a_f * b_f;
	}
	float c_prev = 0.0f;
	bfloat16_to_float( c_ref[( i * rs_c_ref ) + ( j * cs_c_ref )], &c_prev );
	return ( beta * c_prev ) + ( alpha * dot );
}
/* Extract the int4 B element at logical offset 'b_inc' from the packed
 * byte array 'b', sign-extend it, and convert to float. When pre-ops are
 * present, apply the WOQ dequantization (val - zero_point) * scale_factor
 * using the column-'j' (or broadcast) scale and zero point. */
static inline float get_s4_to_f32_scale_val
(
int8_t* b,
dim_t j,
dim_t b_inc,
aocl_pre_op* pre_op
)
{
	/* Two int4 elements live in each byte: even logical index in the low
	 * nibble, odd index in the high nibble, so the byte offset is b_inc/2. */
	int8_t packed = *( b + ( b_inc / 2 ) );
	int8_t b_val = ( ( b_inc % 2 ) != 0 ) ?
	               ( ( packed >> 4 ) & 0x0F ) :
	               ( packed & 0x0F );
	/* Sign-extend the 4-bit value to 8 bits. */
	if ( ( b_val & 0x08 ) != 0 )
	{
		b_val = b_val | 0xF0;
	}
	/* No pre-ops: plain int4 -> float conversion. */
	if ( ( pre_op == NULL ) || ( pre_op->seq_length <= 0 ) )
	{
		return ( float )( b_val );
	}
	/* Length 1 means a single broadcast zero point / scale factor. */
	dim_t j_zp = j;
	if ( ( pre_op->b_zp != NULL ) &&
	     ( ( pre_op->b_zp )->zero_point_len == 1 ) )
	{
		j_zp = 0;
	}
	dim_t j_scale = j;
	if ( ( pre_op->b_scl != NULL ) &&
	     ( ( pre_op->b_scl )->scale_factor_len == 1 ) )
	{
		j_scale = 0;
	}
	/* Assuming only 1 scale and zp. */
	int8_t zp = 0;
	if ( ( pre_op->b_zp != NULL ) &&
	     ( ( pre_op->b_zp )->zero_point != NULL ) )
	{
		zp = *( ( int8_t* )( pre_op->b_zp )->zero_point + j_zp );
	}
	float scale_factor = 1.0;
	if ( ( pre_op->b_scl != NULL ) &&
	     ( ( pre_op->b_scl )->scale_factor != NULL ) )
	{
		scale_factor = *( ( float* )( pre_op->b_scl )->scale_factor + j_scale );
	}
	return ( float )( b_val - zp ) * scale_factor;
}
/* Reference inner product for bf16 x int4 -> f32: A elements are widened
 * from bf16, B elements are unpacked from int4 (with optional WOQ
 * scale/zero-point pre-ops), then beta*C_ref(i,j) + alpha*accum applies. */
static inline float mat_mul_accuracy_check_accum_bf16s4f32of32
(
bfloat16* a,
int8_t* b,
float* c_ref,
float temp_accum,
float alpha,
float beta,
dim_t rs_a,
dim_t rs_b,
dim_t cs_a,
dim_t cs_b,
dim_t rs_c_ref,
dim_t cs_c_ref,
dim_t i,
dim_t j,
dim_t k,
bool int4_testing, /* Ignored: s4 implies int4 testing. */
aocl_pre_op* pre_op /* Optional B-matrix scale/zero-point pre-ops. */
)
{
	( void )int4_testing;
	float dot = temp_accum;
	for ( dim_t p = 0; p < k; ++p )
	{
		float a_f = 0.0f;
		bfloat16_to_float( a[( i * rs_a ) + ( p * cs_a )], &a_f );
		/* Upscale the packed int4 B element to float, applying pre-ops. */
		dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
		float b_f = get_s4_to_f32_scale_val( b, j, b_inc, pre_op );
		dot += a_f * b_f;
	}
	float c_prev = c_ref[( i * rs_c_ref ) + ( j * cs_c_ref )];
	return ( beta * c_prev ) + ( alpha * dot );
}
/* Reference inner product for bf16 x int4 with a bf16 reference C:
 * identical to the f32-output variant except C_ref(i,j) is stored as
 * bf16 and must be widened to float before the beta blend. */
static inline float mat_mul_accuracy_check_accum_bf16s4f32obf16
(
bfloat16* a,
int8_t* b,
bfloat16* c_ref,
float temp_accum,
float alpha,
float beta,
dim_t rs_a,
dim_t rs_b,
dim_t cs_a,
dim_t cs_b,
dim_t rs_c_ref,
dim_t cs_c_ref,
dim_t i,
dim_t j,
dim_t k,
bool int4_testing, /* Ignored: s4 implies int4 testing. */
aocl_pre_op* pre_op /* Optional B-matrix scale/zero-point pre-ops. */
)
{
	( void )int4_testing;
	float dot = temp_accum;
	for ( dim_t p = 0; p < k; ++p )
	{
		float a_f = 0.0f;
		bfloat16_to_float( a[( i * rs_a ) + ( p * cs_a )], &a_f );
		/* Upscale the packed int4 B element to float, applying pre-ops. */
		dim_t b_inc = ( rs_b * p ) + ( cs_b * j );
		float b_f = get_s4_to_f32_scale_val( b, j, b_inc, pre_op );
		dot += a_f * b_f;
	}
	float c_prev = 0.0f;
	bfloat16_to_float( c_ref[( i * rs_c_ref ) + ( j * cs_c_ref )], &c_prev );
	return ( beta * c_prev ) + ( alpha * dot );
}
/* Reference GELU (tanh approximation) post-op helpers. */
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32of32)
GEN_GELU_TANH_POSTOP_FLOAT(bf16s4f32obf16)
/* Reference tanh post-op helpers. */
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_TANH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_TANH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_TANH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_TANH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_TANH_POSTOP_FLOAT(f32f32f32of32)
GEN_TANH_POSTOP_FLOAT(bf16bf16f32of32)
GEN_TANH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_TANH_POSTOP_FLOAT(bf16s4f32of32)
GEN_TANH_POSTOP_FLOAT(bf16s4f32obf16)
/* Reference GELU (erf) post-op helpers. */
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8)
GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os8)
GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32)
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8)
GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16)
GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32of32)
GEN_GELU_ERF_POSTOP_FLOAT(bf16s4f32obf16)
/* Reference swish (SiLU) post-op helpers. */
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os8)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_SWISH_POSTOP_INT(int16_t,u8s8s16os16)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os8)
GEN_SWISH_POSTOP_INT(int32_t,u8s8s32os32)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os8)
GEN_SWISH_POSTOP_INT(int32_t,s8s8s32os32)
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os8)
GEN_SWISH_POSTOP_INT(int16_t,s8s8s16os16)
GEN_SWISH_POSTOP_FLOAT(f32f32f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16bf16f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_SWISH_POSTOP_FLOAT(bf16s4f32of32)
GEN_SWISH_POSTOP_FLOAT(bf16s4f32obf16)
/* Reference sigmoid post-op helpers. */
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os8)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16ou8)
GEN_SIGMOID_POSTOP_INT(int16_t,u8s8s16os16)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os8)
GEN_SIGMOID_POSTOP_INT(int32_t,u8s8s32os32)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os8)
GEN_SIGMOID_POSTOP_INT(int32_t,s8s8s32os32)
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os8)
GEN_SIGMOID_POSTOP_INT(int16_t,s8s8s16os16)
GEN_SIGMOID_POSTOP_FLOAT(f32f32f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16bf16f32obf16)
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32of32)
GEN_SIGMOID_POSTOP_FLOAT(bf16s4f32obf16)
/* Reference matrix-add post-op value getters (per output type). */
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
GEN_GET_MATRIX_ADD_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,f32f32f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16bf16f32of32)
GEN_GET_MATRIX_ADD_POST_OP_VAL(float,float,bf16s4f32of32)
/* Reference matrix-mul post-op value getters (per output type). */
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16bf16f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL_BF16(bfloat16,bf16s4f32obf16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,u8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,u8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,u8s8s16os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(uint8_t,int16_t,u8s8s16ou8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,u8s8s16os16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int32_t,s8s8s32os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int32_t,int32_t,s8s8s32os32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int8_t,int16_t,s8s8s16os8)
GEN_GET_MATRIX_MUL_POST_OP_VAL(int16_t,int16_t,s8s8s16os16)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,f32f32f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16bf16f32of32)
GEN_GET_MATRIX_MUL_POST_OP_VAL(float,float,bf16s4f32of32)
/* Reference bias post-op value getters (per accumulator type). */
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16bf16f32obf16)
GEN_GET_BIAS_POST_OP_VAL_BF16(bf16s4f32obf16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,u8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16ou8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,u8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
/* Accumulator-to-output conversion helpers (ACCUM_type -> C_type). */
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int16_t,int16_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int16_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,int16_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float)
/* Generates mat_mul_accuracy_check_driver_<BLAS_SFX>(): recomputes every
 * C element of every GEMM in the batch via the scalar reference path
 * (inner product, then the per-GEMM post-op sequence: bias, eltwise
 * (PReLU/GELU/ReLU/tanh/sigmoid/swish/clip), downscale, matrix add/mul),
 * converts to the output type, and compares against the library result
 * within a 1e-5 tolerance. On the first mismatch it logs the failing
 * element (to 'fout' if non-NULL, and stdout) and stops checking.
 * Row/column strides are derived per GEMM from storage order and
 * transpose flags. */
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
( \
FILE* fout, \
const char* stor_order, \
char* transa, \
char* transb, \
dim_t bs, \
dim_t* m, \
dim_t* n, \
dim_t* k, \
ACCUM_type* alpha, \
A_type** a, \
dim_t* lda, \
B_type** b, \
dim_t* ldb, \
ACCUM_type* beta, \
C_type** c, \
dim_t* ldc, \
C_type** c_ref, \
dim_t* ldc_ref, \
aocl_post_op** post_op, \
bool int4_testing /* Workaround to enable int4 B matrix testing. */ \
) \
{ \
dim_t rs_a, cs_a; \
dim_t rs_b, cs_b; \
dim_t rs_c, cs_c; \
dim_t rs_c_ref; \
dim_t cs_c_ref; \
for( dim_t bs_i = 0; bs_i < bs; bs_i++ ) \
{ \
if( stor_order[bs_i] == 'r' || stor_order[bs_i] == 'R' ) \
{ \
if( ( transa[bs_i] == 'n' ) || ( transa[bs_i] == 'N' ) ) \
{ \
rs_a = lda[bs_i]; \
cs_a = 1; \
} \
else \
{ \
rs_a = 1; \
cs_a = lda[bs_i]; \
} \
if( ( transb[bs_i] == 'n' ) || ( transb[bs_i] == 'N' ) ) \
{ \
rs_b = ldb[bs_i]; \
cs_b = 1; \
} \
else \
{ \
rs_b = 1; \
cs_b = ldb[bs_i]; \
} \
rs_c = ldc[bs_i]; \
cs_c = 1; \
rs_c_ref = ldc_ref[bs_i]; \
cs_c_ref = 1; \
} \
else /* column storage */ \
{ \
if( transa[bs_i] == 'n' || transa[bs_i] == 'N') \
{ \
rs_a = 1; \
cs_a = lda[bs_i]; \
} \
else \
{ \
rs_a= lda[bs_i]; \
cs_a = 1; \
} \
if( ( transb[bs_i] == 'n' ) || ( transb[bs_i] == 'N' ) ) \
{ \
rs_b = 1; \
cs_b = ldb[bs_i]; \
} \
else \
{ \
rs_b = ldb[bs_i]; \
cs_b = 1; \
} \
rs_c = 1; \
cs_c = ldc[bs_i]; \
rs_c_ref = 1; \
cs_c_ref = ldc_ref[bs_i]; \
} \
aocl_pre_op* a_pre_op = NULL; \
if ( post_op[bs_i] != NULL ) \
{ \
a_pre_op = post_op[bs_i]->pre_ops; \
} \
for ( dim_t i = 0; i < m[bs_i]; ++i ) \
{ \
for ( dim_t j = 0; j < n[bs_i]; ++j ) \
{ \
ACCUM_type temp_accum = 0; \
C_type out_temp_accum = 0; \
\
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \
(a[bs_i], b[bs_i], c_ref[bs_i], temp_accum, alpha[bs_i], beta[bs_i],\
rs_a, rs_b, cs_a, cs_b, rs_c_ref, cs_c_ref, i, j, k[bs_i], \
int4_testing, a_pre_op); \
\
if ( post_op[bs_i] != NULL ) \
{ \
dim_t ele_i = 0; \
for ( dim_t op_id = 0; op_id < post_op[bs_i]->seq_length; ++op_id ) \
{ \
if ( post_op[bs_i]->seq_vector[op_id] == BIAS ) \
{ \
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
( ( post_op[bs_i]->bias )->bias, j, ( post_op[bs_i]->bias )->stor_type ); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == ELTWISE ) \
{ \
if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
PRELU ) /* PReLU*/ \
{ \
temp_accum = ( temp_accum > 0 ) ? \
temp_accum : \
( temp_accum * \
*( ( ACCUM_type* ) ( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) ); \
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
GELU_TANH ) /* TANH GeLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
GELU_ERF ) /* ERF GeLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
SWISH ) /* SiLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
(temp_accum, \
*( ( ACCUM_type* ) \
( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) );\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
RELU ) /* ReLU*/ \
{ \
temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
TANH ) /* TANH*/ \
{ \
temp_accum = GEN_FUNC_NAME(TANH_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
SIGMOID ) /* Sigmoid*/ \
{ \
temp_accum = GEN_FUNC_NAME(SIGMOID_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
CLIP ) /* CLIP*/ \
{ \
temp_accum = \
min \
( \
max \
( \
temp_accum, \
*( ( ACCUM_type* ) \
( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) \
), \
*( ( ACCUM_type* ) \
( post_op[bs_i]->eltwise + ele_i )->algo.beta) \
); \
ele_i += 1; \
} \
else \
{} \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == SCALE ) \
{ \
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(temp_accum, post_op[bs_i], j); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_ADD ) \
{ \
dim_t rs_m = ( post_op[bs_i]->matrix_add )->ldm; \
dim_t cs_m = 1; \
if ( ( stor_order[bs_i] == 'C' ) || ( stor_order[bs_i] == 'c' ) ) \
{ \
cs_m = rs_m; \
rs_m = 1; \
} \
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_add )->scale_factor_len; \
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( *( ( C_type* )( post_op[bs_i]->matrix_add )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_MUL ) \
{ \
dim_t rs_m = ( post_op[bs_i]->matrix_mul )->ldm; \
dim_t cs_m = 1; \
if ( ( stor_order[bs_i] == 'C' ) || ( stor_order[bs_i] == 'c' ) ) \
{ \
cs_m = rs_m; \
rs_m = 1; \
} \
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
( *( ( C_type* )( post_op[bs_i]->matrix_mul )->matrix + \
( i * rs_m ) + ( j * cs_m ) ), \
j, scl_fctr, scl_fctr_len ); \
} \
else \
{} \
} \
} \
/* Need to convert to downscaled type if required.*/ \
mat_mul_get_output_type_val ## ACCUM_type ## C_type \
( \
&out_temp_accum, &temp_accum \
); \
\
float comp_float, ref_float; \
GEN_FUNC_NAME(C_type,_to_float)(*( c[bs_i] + ( rs_c * i ) + ( cs_c * j ) ), &comp_float); \
GEN_FUNC_NAME(C_type,_to_float)(out_temp_accum, &ref_float); \
\
if ( ( ( comp_float - ref_float ) > 1.0E-5 ) || \
( ( ref_float - comp_float ) > 1.0E-5 ) ) \
{ \
if ( fout ) \
{ \
fprintf( fout, "%s Failure input gemm:%ld, m: %ld, n: %ld, k: %ld," \
" lda: %ld, ldb: %ld, ldc: %ld, computed:%f, ref:%f, diff:%f\n", \
XSTR(BLAS_SFX), bs_i, m[bs_i], n[bs_i], k[bs_i], lda[bs_i], ldb[bs_i], ldc[bs_i], comp_float, \
ref_float, comp_float - ref_float); \
fflush( fout ); \
} \
printf("failure, gemm:%ld, m_index: %ld, n_index: %ld, k: %ld, computed:%f, ref:%f," \
"diff:%f\n", bs_i, i, j, k[bs_i], comp_float, ref_float, comp_float-ref_float); \
goto cleanup_acc; \
} \
} \
} \
} \
cleanup_acc: \
return; \
} \
/* Instantiate accuracy-check drivers for every supported datatype combo.
 * The last argument selects the downscale helper used for SCALE post-ops. */
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int16_t,float,u8s8s16ou8,u8s8s16ou8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32obf16,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,f32f32f32of32)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,float,float,float,bf16s4f32of32,bf16bf16f32obf16)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,int8_t,bfloat16,float,float,bf16s4f32obf16,bf16bf16f32obf16)
/* Instantiate post-ops structure creator helpers per base accumulator type. */
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,u8s8s16os16)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,u8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(bfloat16,float,float,bfloat16,bf16bf16f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(float,float,float,float,f32f32f32of32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int32_t,float,int32_t,s8s8s32os32)
GEN_MAT_MUL_POST_OPS_CREATOR(int8_t,int16_t,float,int16_t,s8s8s16os16)
// Hack to fix compiler errors: maps a base API suffix to its B-matrix
// element type so generic macros can name the B type from BLAS_SFX alone.
#define GET_B_TYPE_bf16bf16f32of32 bfloat16
#define GET_B_TYPE_u8s8s16os16 int8_t
#define GET_B_TYPE_u8s8s32os32 int8_t
#define GET_B_TYPE_f32f32f32of32 float
#define GET_B_TYPE_s8s8s32os32 int8_t
#define GET_B_TYPE_s8s8s16os16 int8_t
#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type, B_type, C_type, Sum_type, BLAS_SFX, REORDER_SFX, INT4_REORDER_SFX) \
void mat_mul_bench_main_ ## BLAS_SFX \
( \
FILE* fin, \
FILE* fout, \
char* stor_order, \
char* transa, \
char* transb, \
char* op_a, \
char* op_b, \
dim_t bs, \
dim_t* m, \
dim_t* n, \
dim_t* k, \
dim_t* stride_a, \
dim_t* stride_b, \
dim_t* stride_c, \
char (*post_ops_str)[POST_OPS_STR_LEN], \
bool int4_testing /* Workaround to enable int4 B matrix testing. */\
) \
{ \
int32_t n_repeats = 1000;/* = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 1000 ));*/ \
if ( global_n_repeat > 0 ) \
{ \
n_repeats = global_n_repeat; \
} \
\
/* creating an array of pointers to A, B and C matrices in the batch */ \
A_type** a = ( A_type** ) lpgemm_malloc( sizeof( A_type* ) * bs ); \
B_type** b = ( B_type** ) lpgemm_malloc( sizeof( B_type* ) * bs ); \
C_type** c = ( C_type** ) lpgemm_malloc( sizeof( C_type* ) * bs ); \
C_type** c_ref = ( C_type** ) lpgemm_malloc( sizeof( C_type* ) * bs ); \
B_type** b_gemm = ( B_type** ) lpgemm_malloc( sizeof( B_type*) * bs ); \
Sum_type alpha[bs]; \
Sum_type beta[bs]; \
aocl_post_op** post_op = (aocl_post_op**)lpgemm_malloc(sizeof(aocl_post_op*) * bs); \
for( dim_t i = 0; i < bs; i++ ) \
{ \
dim_t size_A = 0; \
dim_t size_B = 0; \
dim_t size_C = 0; \
if( ( stor_order[i] == 'r' ) || ( stor_order[i] == 'R' ) ) \
{ \
size_A = ( ( transa[i] == 'n' ) || ( transa[i] == 'N' ) ) ? m[i] * stride_a[i] : k[i] * stride_a[i]; \
size_B = ( ( transb[i] == 'n' ) || ( transb[i] == 'N' ) ) ? k[i] * stride_b[i] : n[i] * stride_b[i]; \
size_C = m[i] * stride_c[i]; \
} \
else \
{ \
size_A = ( ( transa[i] == 'n' ) || ( transa[i] == 'N' ) ) ? k[i] * stride_a[i] : m[i] * stride_a[i]; \
size_B = ( ( transb[i] == 'n' ) || ( transb[i] == 'N' ) ) ? n[i] * stride_b[i] : k[i] * stride_b[i]; \
size_C = n[i] * stride_c[i]; \
} \
a[i] = ( A_type* ) lpgemm_malloc( sizeof( A_type ) * size_A ); \
GEN_FUNC_NAME(fill_array_,A_type)(a[i], size_A ); \
b[i] = ( B_type* ) lpgemm_malloc( sizeof( B_type ) * size_B ); \
if ( int4_testing == FALSE ) \
{ \
GEN_FUNC_NAME(fill_array_,B_type)(b[i], size_B ); \
} \
else \
{ \
GEN_FUNC_NAME(fill_array_,int4_c_t)(b[i], size_B); \
} \
c[i] = ( C_type* ) lpgemm_malloc( sizeof( C_type ) * size_C ); \
c_ref[i] = ( C_type* ) lpgemm_malloc( sizeof( C_type ) * size_C ); \
if ( bench_mode == 'a' ) \
{ \
GEN_FUNC_NAME(fill_array_,C_type)(c[i], size_C ); \
memcpy(c_ref[i], c[i] , (size_C * sizeof(C_type))); \
} \
else \
{ \
memset( ( void* ) c[i], 0, sizeof( C_type ) * size_C ); \
memset( ( void* ) c_ref[i], 0, sizeof( C_type ) * size_C ); \
} \
\
if ( bench_mode == 'p' ) \
{ \
alpha[i] = 1; \
beta[i] = 0; \
} \
else if ( bench_mode == 'a' ) \
{ \
n_repeats = 1; \
alpha[i] = (i + 1) % 5; \
beta[i] = (i + 5 ) % 9; \
} \
if ( ( ( post_ops_str[i] != NULL ) && \
( strcmp( post_ops_str[i], "none" ) != 0 ) ) || \
( global_dscale_out == 'y' ) || ( global_pre_op == 'y' ) ) \
{ \
post_op[i] = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m[i], n[i], k[i], post_ops_str[i], stor_order[i] ); \
if ( post_op[i] == NULL ) \
{ \
printf(" post op struct allocation failure, returning.\n"); \
return; \
} \
} \
else \
{ \
post_op[i] = NULL; \
} \
if ( ( op_b[i] == 'p' ) || ( op_b[i] == 'P' ) || ( op_b[i] == 'n' ) || ( op_b[i] == 'N' ) ) \
{ \
b_gemm[i] = b[i]; \
} \
else if ( ( op_b[i] == 'r' ) || ( op_b[i] == 'R' ) ) \
{ \
if ( int4_testing == FALSE ) \
{ \
siz_t b_reorder_buf_siz_req = \
GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( stor_order[i], transb[i], 'B', k[i], n[i] ); \
b_gemm[i] = ( B_type* ) lpgemm_malloc( b_reorder_buf_siz_req ); \
GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( stor_order[i], transb[i], 'B', \
( GET_B_TYPE_ ## REORDER_SFX * )b[i], \
( GET_B_TYPE_ ## REORDER_SFX * )b_gemm[i], \
k[i], n[i], stride_b[i] ); \
} \
/* It has to be ensured, for now, only int4 testing takes else path. */ \
else \
{ \
siz_t b_reorder_buf_siz_req = \
GEN_FUNC_NAME(aocl_get_reorder_buf_size_,INT4_REORDER_SFX)( stor_order[i], transb[i], 'B', k[i], n[i] ); \
\
b_gemm[i] = ( B_type* ) lpgemm_malloc( b_reorder_buf_siz_req ); \
GEN_FUNC_NAME(aocl_reorder_,INT4_REORDER_SFX)( stor_order[i], transb[i], 'B', \
( int8_t* )b[i], ( int8_t* )b_gemm[i], k[i], n[i], stride_b[i] ); \
} \
} \
} /* Done with initializing inputs, */ \
/* Reordering has already been taken care of. */ \
GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \
( \
stor_order, transa, transb, op_a, op_b, n_repeats, bs, m, n, k, \
alpha, \
a, stride_a, \
b_gemm, stride_b, \
beta, \
c, stride_c, \
post_op \
); \
\
if ( bench_mode == 'a' ) \
{ \
printf(" Running accuracy check.\n"); \
fflush(stdout); \
GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \
( \
fout, stor_order, transa, transb, bs, m, n, k, \
alpha, \
a, stride_a, \
b, stride_b, \
beta, \
c, stride_c, \
c_ref, stride_c, \
post_op, int4_testing \
); \
} \
\
for( dim_t i = 0; i < bs; i++ ) \
{ \
lpgemm_free( a[i] ); \
lpgemm_free( b[i] ); \
lpgemm_free( c[i] ); \
lpgemm_free( c_ref[i] ); \
lpgemm_destroy_post_ops_struct( post_op[i] ); \
if( ( op_b[i] == 'r' ) || ( op_b[i] == 'R' ) ) \
{ \
lpgemm_free( b_gemm[i]); \
} \
} \
lpgemm_free( a ); \
lpgemm_free( b ); \
lpgemm_free( c ); \
lpgemm_free( c_ref ); \
lpgemm_free( b_gemm ); \
lpgemm_free( post_op ); \
} \
// Instantiate one mat_mul_bench_main_<sfx> per supported datatype combo.
// Arguments: A type, B type, C type, accumulation type, BLAS suffix,
// regular-reorder suffix, int4-reorder suffix.
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,float,f32f32f32of32,f32f32f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32,u8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8,u8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8,s8s8s32os32,u8s4s32os32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,float,float,bf16s4f32of32,bf16bf16f32of32,bf16s4f32of32)
GEN_MAT_MUL_BENCH_MAIN_FUNC(bfloat16,int8_t,bfloat16,float,bf16s4f32obf16,bf16bf16f32of32,bf16s4f32of32)
// Bench entry point. Parses CLI options (-i input file, -m a|p mode,
// -n repeat count), then repeatedly reads batches from the input file and
// dispatches each to the matching datatype-specific mat_mul_bench_main_*
// instantiation. Input file format per batch:
//   <gemm_type>:bs=<batch size>
// followed by <bs> lines, each:
//   <stor> <transa> <transb> <op_a> <op_b> m n k lda ldb ldc <post_ops>
// In accuracy mode the whole file is replayed once per entry of
// list_omp_cores_for_testing to exercise different thread counts.
// Returns 0; exits with status 1 on bad arguments or file errors.
int main( int argc, char** argv )
{
	FILE* fin = NULL;
	if ( argc < 5 )
	{
		printf
		(
		  "Usage: ./bench_batch_lpgemm -i input.txt -m mode < -n 100 >\n" \
		  "--Mode is either a or p.\n" \
		  "\ta is used for accuracy testing.\n" \
		  "\tp is used for performance benchmarking.\n" \
		  "--n_repeats can be set optionally using -n arg.\n" \
		  "--Post ops can be executed optionaly by providing a coma separated\n" \
		  " list of post-ops in the input file. Following post-ops are supported:\n" \
		  " 1. bias\n" \
		  " 2. 4 activators\n" \
		  " a. relu\n" \
		  " b. prelu\n" \
		  " c. gelu_tanh\n" \
		  " d. gelu_erf\n" \
		  " 3.clip\n" \
		  " It is to be noted only one activator can be used at a time.\n" \
		  " If more than one activator is used, only the first activator is\n" \
		  " applied and the other activators are ignored.\n" \
		  " Downscaled api's are used to enable quantization workflows.\n" \
		  " Following downscaled api's are supported:\n" \
		  " 1. bf16bf16f32of32 -d bf16 = bf16bf16f32obf16.\n" \
		  " Example: ./bench_batch_lpgemm -m a -n 2 -i input.txt\n" \
		);
		exit( 1 );
	}

	char* file_name = NULL;
#define MAX_LINE_LENGTH 256 // Fits the batch header line "<gemm_type>:bs=<n>" and each config line.
	char line[MAX_LINE_LENGTH];
#define GEMM_TYPE_STR_LEN 24
	char gemm_type_str[GEMM_TYPE_STR_LEN];

	// Parse CLI arguments.
	getopt_t state;
	// Initialize the state for running bli_getopt(). Here, 0 is the
	// initial value for opterr, which suppresses error messages.
	bli_getopt_init_state( 0, &state );

	int opt;
	// Process all option arguments until we get a -1, which means we're done.
	while( (opt = bli_getopt( argc, argv, "i:m:n:", &state )) != -1 )
	{
		char opt_ch = ( char )opt;
		switch( opt_ch )
		{
			case 'i':
					file_name = state.optarg;
					break;
			case 'm':
					// Any mode other than 'a' (accuracy) falls back to 'p' (perf).
					bench_mode = ( ( ( *state.optarg ) == 'a' ) || ( ( *state.optarg ) == 'p' ) ) ? ( *state.optarg ) : 'p';
					break;
			case 'n':
					global_n_repeat = ( atoi( state.optarg ) > 0 ) ? atoi( state.optarg ) : 0;
					break;
			default:
					break;
		}
	}

	if ( bench_mode == 'p' )
	{
		printf( "Running bench in performance benchmarking mode.\n" );
	}
	else if ( bench_mode == 'a' )
	{
		printf( "Running bench in accuracy/correctness testing mode.\n" );
	}

	if ( file_name == NULL )
	{
		printf( " File name provided is invalid.\n" );
		exit( 1 );
	}

	fin = fopen( file_name, "r" );
	if (fin == NULL)
	{
		// Bug fix: report the actual input file name; argv[1] holds the
		// first option flag (e.g. "-i"), not the path.
		printf( "Error opening the file %s\n", file_name );
		exit( 1 );
	}

	FILE* fout = NULL;
	fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" );
	if ( fout == NULL )
	{
		// Bug fix: the accuracy-check drivers receive fout and write to it;
		// proceeding with a NULL stream would be undefined behavior.
		printf( "Error opening the file lpgemm_accuracy_test_failures.txt\n" );
		fclose( fin );
		exit( 1 );
	}

	// batch size
	dim_t bs;

	const dim_t len_list_omp_cores_for_testing = 2;
	const dim_t list_omp_cores_for_testing[2] = { 1, 128 };

	dim_t core_index = 0;
	bool can_run = TRUE;
	while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) )
	{
		if ( bench_mode == 'p' )
		{
			can_run = FALSE;
		}
		else if ( bench_mode == 'a' )
		{
			// For accuracy testing, we test accuracy using multiple different
			// number of cores. This helps uncover any bugs related to over
			// subscription or varying thread factorizations.
			// Set current number of cores.
#ifdef BLIS_ENABLE_OPENMP
			omp_set_num_threads( list_omp_cores_for_testing[core_index] );
#endif
			printf( "Accuracy test using %ld threads.\n",
							list_omp_cores_for_testing[core_index] );

			core_index++;
			if ( core_index < len_list_omp_cores_for_testing )
			{
				can_run = TRUE;
			}
			else
			{
				can_run = FALSE;
			}
		}

		// Process the file until no more data remains
		while (fgets(line, MAX_LINE_LENGTH, fin))
		{
			// Step 1: Extract 'op' and 'bs' from the first line
			if (sscanf(line, "%23[^:]:bs=%ld", gemm_type_str, &bs) != 2)
			{
				printf("Error: Failed to parse the first line.\n");
				break;
			}
			// Bug fix: a non-positive batch size would make the VLA
			// declarations below undefined behavior.
			if ( bs <= 0 )
			{
				printf("Error: Invalid batch size in the first line.\n");
				break;
			}
			char op_a[bs], op_b[bs];
			char stor_order[bs];
			char transa[bs], transb[bs];
			dim_t m[bs], n[bs], k[bs];
			dim_t stride_a[bs], stride_b[bs], stride_c[bs];
			char post_ops_str[bs][POST_OPS_STR_LEN];
			char post_ops_str_dest[bs][POST_OPS_STR_LEN]; //Strtok is used to parse, need to maintain a copy.

			// Step 2: Read the next 'bs' number of lines and parse them
			bool batch_ok = TRUE;
			for (int i = 0; i < bs; i++)
			{
				if (fgets(line, MAX_LINE_LENGTH, fin)) {
					if (sscanf( line, "%c %c %c %c %c " INT_FS INT_FS INT_FS
						INT_FS INT_FS INT_FS " %s", &(stor_order[i]), &(transa[i]),
						&(transb[i]), &(op_a[i]), &(op_b[i]), &(m[i]), &(n[i]), &(k[i]),
						&(stride_a[i]), &(stride_b[i]), &(stride_c[i]), post_ops_str[i] ) != 12) {
						printf(" Error parsing line %d\n", i + 1);
						batch_ok = FALSE;
						break;
					}
				} else {
					printf("Error: Not enough lines to match 'bs' value.\n");
					batch_ok = FALSE;
					break;
				}
				// Normalize storage order; anything other than r/R/c/C
				// defaults to row-major.
				stor_order[i] = ( ( stor_order[i] == 'r' ) || ( stor_order[i] == 'R' ) ||
								  ( stor_order[i] == 'c' ) || ( stor_order[i] == 'C' ) ) ?
								  stor_order[i] : 'r';
			}
			// Bug fix: on a short or malformed batch the remaining array
			// entries are uninitialized; running any gemm with them would
			// be undefined behavior, so skip this batch entirely.
			if ( batch_ok == FALSE )
			{
				goto skip_exec;
			}
			if ( ( strcmp( gemm_type_str, "bf16bf16f32of32" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_dscale_out = 'n';
				global_pre_op = 'n';
				GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( ( strcmp( gemm_type_str, "bf16bf16f32obf16" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_dscale_out = 'n';
				global_pre_op = 'n';
				GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( ( strcmp( gemm_type_str, "f32f32f32of32" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_can_dscale = 'y';
				global_dscale_out = 'n';
				global_pre_op = 'n';
				GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( ( strcmp( gemm_type_str, "u8s8s32os32" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				// Copy the original post op str to a temp string buffer.
				// Done so that strtok can be applied on the same (strtok
				// is a destructive parser.
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_dscale_out = 'n';
				global_pre_op = 'n';
				DSCALE_CLIP_MIN = INT_MIN;
				DSCALE_CLIP_MAX = INT_MAX;
				GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( ( strcmp( gemm_type_str, "u8s8s32os8" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				// Copy the original post op str to a temp string buffer.
				// Done so that strtok can be applied on the same (strtok
				// is a destructive parser.
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_dscale_out = 'y';
				global_pre_op = 'n';
				DSCALE_CLIP_MIN = -128;
				DSCALE_CLIP_MAX = +127;
				GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( ( strcmp( gemm_type_str, "s8s8s32os32" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				// Copy the original post op str to a temp string buffer.
				// Done so that strtok can be applied on the same (strtok
				// is a destructive parser.
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_dscale_out = 'n';
				global_pre_op = 'n';
				DSCALE_CLIP_MIN = INT_MIN;
				DSCALE_CLIP_MAX = INT_MAX;
				GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( ( strcmp( gemm_type_str, "s8s8s32os8" ) == 0 ) ||
				 ( strcmp( gemm_type_str, "*" ) == 0 ) )
			{
				// Copy the original post op str to a temp string buffer.
				// Done so that strtok can be applied on the same (strtok
				// is a destructive parser.
				for( dim_t i = 0; i < bs; i++ )
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
				global_dscale_out = 'y';
				global_pre_op = 'n';
				DSCALE_CLIP_MIN = -128;
				DSCALE_CLIP_MAX = +127;
				GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, FALSE
				);
			}
			if ( strcmp( gemm_type_str, "bf16s4f32of32" ) == 0 )
			{
				// Copy the original post op str to a temp string buffer.
				// Done so that strtok can be applied on the same (strtok
				// is a destructive parser.
				for( dim_t i = 0; i < bs; i++ )
				{
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
					// Int4 B matrices are only supported through the
					// reorder path.
					if ( ( op_b[i] != 'r' ) && ( op_b[i] != 'R' ) )
					{
						printf("Int4 B matrix only permitted if B reodering "
							   "is enabled.\n");
						goto skip_exec;
					}
				}
				global_dscale_out = 'n';
				global_pre_op = 'y';
				GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32of32)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, TRUE
				);
			}
			if ( strcmp( gemm_type_str, "bf16s4f32obf16" ) == 0 )
			{
				// Copy the original post op str to a temp string buffer.
				// Done so that strtok can be applied on the same (strtok
				// is a destructive parser.
				for( dim_t i = 0; i < bs; i++ )
				{
					strncpy( post_ops_str_dest[i], post_ops_str[i], POST_OPS_STR_LEN );
					// Int4 B matrices are only supported through the
					// reorder path.
					if ( ( op_b[i] != 'r' ) && ( op_b[i] != 'R' ) )
					{
						printf("Int4 B matrix only permitted if B reodering "
							   "is enabled.\n");
						goto skip_exec;
					}
				}
				global_dscale_out = 'y';
				global_pre_op = 'y';
				GEN_FUNC_NAME(mat_mul_bench_main_, bf16s4f32obf16)
				(
				  fin, fout, stor_order, transa, transb, op_a, op_b,
				  bs, m, n, k, stride_a, stride_b, stride_c,
				  post_ops_str_dest, TRUE
				);
			}
			skip_exec:;
		}
	}

	if ( fin )
	{
		fclose( fin );
	}
	if ( fout )
	{
		fclose( fout );
	}
	return 0;
}