Added destination scale type check in INT8 API's

- Updated the S8 main, GEMV, m_, n_ and mn_ fringe kernels to support
   multiple scale types for vector and scalar scales

 - Updated the U8 main, GEMV, m_, n_, extMR_ and mn_ fringe kernels to
   support multiple scale types for vector and scalar scales

 - Updated the bench to accommodate multiple scale type input, and
   modified the downscale_accuracy_check_ to verify with multiple scale
   type inputs.

AMD Internal: [ SWLCSG-3304 ]

Change-Id: I7b9f3ec8ea830d3265f72d18a0aa36086e14a86e
This commit is contained in:
varshav
2025-03-24 06:30:07 +00:00
committed by Nallani Bhaskar
parent 350c7186e5
commit 81d219e3f8
21 changed files with 4469 additions and 1094 deletions

View File

@@ -616,7 +616,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
else if ( post_op[bs_i]->seq_vector[op_id] == SCALE ) \
{ \
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(post_temp_accum, post_op[bs_i], j); \
(post_temp_accum, post_op[bs_i], j, ( post_op[bs_i]->sum )->sf_stor_type); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_ADD ) \
{ \

View File

@@ -598,7 +598,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
else if ( post_op->seq_vector[op_id] == SCALE ) \
{ \
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(post_temp_accum, post_op, j); \
(post_temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
{ \

View File

@@ -261,7 +261,7 @@ static inline void fill_array_bfloat16( void* arr, dim_t size )
#endif
for ( dim_t i = 0; i < size; ++i )
{
c_float[i] = (float)(i % 5);
c_float[i] = (float)( i % 5);
}
convert_float_arr_to_bf16( c_float, arr, size );
if ( c_float != NULL )
@@ -799,12 +799,48 @@ void print_matrix_bfloat16
}
}
float convert_scale_store_type_to_float
(
aocl_post_op* post_op,
AOCL_PARAMS_STORAGE_TYPES sf_stor_type,
dim_t j_scale
)
{
float scale_float = 0.0;
if(sf_stor_type == AOCL_GEMM_BF16)
{
bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->scale_factor + j_scale ),
&scale_float );
}
else if(sf_stor_type == AOCL_GEMM_INT32)
{
int32_t_to_float( *( ( int32_t* )( post_op->sum )->scale_factor + j_scale ),
&scale_float );
}
else if(sf_stor_type == AOCL_GEMM_INT8)
{
int8_t_to_float( *( ( int8_t* )( post_op->sum )->scale_factor + j_scale ),
&scale_float );
}
else if(sf_stor_type == AOCL_GEMM_UINT8)
{
uint8_t_to_float( *( ( uint8_t* )( post_op->sum )->scale_factor + j_scale ),
&scale_float );
}
else
{
scale_float = *( ( float* )( post_op->sum )->scale_factor + j_scale );
}
return scale_float;
}
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ZP_type,C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
(\
ACCUM_type temp_accum,\
aocl_post_op* post_op, \
dim_t j \
dim_t j, \
AOCL_PARAMS_STORAGE_TYPES sf_stor_type \
)\
{ \
dim_t j_scale = j; \
@@ -819,10 +855,11 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
j_zp = 0; \
} \
\
float temp_sf = convert_scale_store_type_to_float(post_op, sf_stor_type, j_scale); \
ACCUM_type out_temp_accum = \
( ACCUM_type )min( \
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
( temp_sf ) ) + \
*( ( ZP_type* )( post_op->sum )->zero_point + j_zp ), \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
@@ -834,9 +871,11 @@ static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
dim_t j,
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
)
{
( void ) sf_stor_type;
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
@@ -861,9 +900,11 @@ static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
(
float temp_accum,
aocl_post_op* post_op,
dim_t j
dim_t j,
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
)
{
( void ) sf_stor_type;
dim_t j_scale = j;
if ( ( post_op->sum )->scale_factor_len == 1 )
{
@@ -1415,6 +1456,8 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
char * bias_stor_type = ""; \
bool is_zp_stor_type = FALSE; \
char* zp_stor_type = ""; \
bool is_sf_stor_type = FALSE; \
char* sf_stor_type = ""; \
bool is_matadd_stor_type = FALSE; \
char* matadd_stor_type = ""; \
bool is_matmul_stor_type = FALSE; \
@@ -1537,6 +1580,39 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
is_scalar_scale = TRUE; \
} \
} \
else if ( strcmp( ops_tok, "sf_stor_type" ) == 0) \
{ \
ops_tok = strtok( NULL, ", " ); \
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
{ \
is_sf_stor_type = FALSE; \
} \
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
{ \
is_sf_stor_type = TRUE; \
sf_stor_type = "F32"; \
} \
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
{ \
is_sf_stor_type = TRUE; \
sf_stor_type = "BF16"; \
} \
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
{ \
is_sf_stor_type = TRUE; \
sf_stor_type = "S32"; \
} \
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
{ \
is_sf_stor_type = TRUE; \
sf_stor_type = "S8"; \
} \
else if ( ( strcmp( ops_tok, "u8" ) == 0 ) ) \
{ \
is_sf_stor_type = TRUE; \
sf_stor_type = "U8"; \
} \
} \
else if ( strcmp( ops_tok, "zp" ) == 0 ) \
{ \
ops_tok = strtok( NULL, ", " ); \
@@ -2007,16 +2083,78 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
} \
\
/* Allocate scale buffer, return early if alloc fails.*/ \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
/* Fill scale factor */ \
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
GEN_FUNC_NAME(fill_array_,DSCALE_type)(temp_dscale_ptr, n_scale); \
( post_ops->sum )->scale_factor_len = n_scale; \
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
if(is_sf_stor_type == TRUE) \
{ \
if( ( strcmp( sf_stor_type, "BF16" ) == 0 ) ) \
{ \
( post_ops->sum )->sf_stor_type = AOCL_GEMM_BF16; \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( bfloat16 ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,bfloat16)( ( post_ops->sum )->scale_factor, n_scale ); \
( post_ops->sum )->scale_factor_len = n_scale; \
} \
else if( ( strcmp( sf_stor_type, "F32" ) == 0 ) ) \
{ \
( post_ops->sum )->sf_stor_type = AOCL_GEMM_F32; \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( float ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->sum )->scale_factor, n_scale ); \
( post_ops->sum )->scale_factor_len = n_scale; \
} \
else if( ( strcmp( sf_stor_type, "S32" ) == 0 ) ) \
{ \
( post_ops->sum )->sf_stor_type = AOCL_GEMM_INT32; \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( int32_t ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->sum )->scale_factor, n_scale ); \
( post_ops->sum )->scale_factor_len = n_scale; \
} \
else if( ( strcmp( sf_stor_type, "S8" ) == 0 ) ) \
{ \
( post_ops->sum )->sf_stor_type = AOCL_GEMM_INT8; \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( int8_t ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->sum )->scale_factor, n_scale ); \
( post_ops->sum )->scale_factor_len = n_scale; \
} \
else if( ( strcmp( sf_stor_type, "U8" ) == 0 ) ) \
{ \
( post_ops->sum )->sf_stor_type = AOCL_GEMM_UINT8; \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( uint8_t ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
GEN_FUNC_NAME(fill_array_,uint8_t)( ( post_ops->sum )->scale_factor, n_scale ); \
( post_ops->sum )->scale_factor_len = n_scale; \
} \
else {} \
} \
else \
{ \
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
if ( ( post_ops->sum )->scale_factor == NULL ) \
{ \
goto err_handler; \
} \
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
GEN_FUNC_NAME(fill_array_,DSCALE_type)(temp_dscale_ptr, n_scale); \
( post_ops->sum )->scale_factor_len = n_scale; \
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
} \
\
if(is_zp_stor_type == TRUE) \
{ \
@@ -2086,7 +2224,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
goto err_handler; \
} \
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)(temp_dzero_point_ptr, n_zp); \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)(temp_dzero_point_ptr, n_zp); \
( post_ops->sum )->zero_point_len = n_zp; \
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_zp;i++) temp_dzero_point_ptr[i] = abs(temp_dzero_point_ptr[i]);\
} \

View File

@@ -404,7 +404,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
else if ( post_op->seq_vector[op_id] == SCALE ) \
{ \
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_SFX) \
(temp_accum, post_op, j); \
(temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
{ \