mirror of
https://github.com/amd/blis.git
synced 2026-04-23 09:08:51 +00:00
Added destination scale type check in INT8 APIs
- Updated the S8 main, GEMV, m_, n_ and mn_ fringe kernels to support multiple scale types for vector and scalar scales.
- Updated the U8 main, GEMV, m_, n_, extMR_ and mn_ fringe kernels to support multiple scale types for vector and scalar scales.
- Updated the bench to accommodate multiple scale type inputs, and modified downscale_accuracy_check_ to verify against multiple scale type inputs.

AMD Internal: [ SWLCSG-3304 ]
Change-Id: I7b9f3ec8ea830d3265f72d18a0aa36086e14a86e
This commit is contained in:
@@ -616,7 +616,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
else if ( post_op[bs_i]->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
|
||||
(post_temp_accum, post_op[bs_i], j); \
|
||||
(post_temp_accum, post_op[bs_i], j, ( post_op[bs_i]->sum )->sf_stor_type); \
|
||||
} \
|
||||
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
|
||||
@@ -598,7 +598,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
else if ( post_op->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
|
||||
(post_temp_accum, post_op, j); \
|
||||
(post_temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
|
||||
@@ -261,7 +261,7 @@ static inline void fill_array_bfloat16( void* arr, dim_t size )
|
||||
#endif
|
||||
for ( dim_t i = 0; i < size; ++i )
|
||||
{
|
||||
c_float[i] = (float)(i % 5);
|
||||
c_float[i] = (float)( i % 5);
|
||||
}
|
||||
convert_float_arr_to_bf16( c_float, arr, size );
|
||||
if ( c_float != NULL )
|
||||
@@ -799,12 +799,48 @@ void print_matrix_bfloat16
|
||||
}
|
||||
}
|
||||
|
||||
float convert_scale_store_type_to_float
|
||||
(
|
||||
aocl_post_op* post_op,
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type,
|
||||
dim_t j_scale
|
||||
)
|
||||
{
|
||||
float scale_float = 0.0;
|
||||
if(sf_stor_type == AOCL_GEMM_BF16)
|
||||
{
|
||||
bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->scale_factor + j_scale ),
|
||||
&scale_float );
|
||||
}
|
||||
else if(sf_stor_type == AOCL_GEMM_INT32)
|
||||
{
|
||||
int32_t_to_float( *( ( int32_t* )( post_op->sum )->scale_factor + j_scale ),
|
||||
&scale_float );
|
||||
}
|
||||
else if(sf_stor_type == AOCL_GEMM_INT8)
|
||||
{
|
||||
int8_t_to_float( *( ( int8_t* )( post_op->sum )->scale_factor + j_scale ),
|
||||
&scale_float );
|
||||
}
|
||||
else if(sf_stor_type == AOCL_GEMM_UINT8)
|
||||
{
|
||||
uint8_t_to_float( *( ( uint8_t* )( post_op->sum )->scale_factor + j_scale ),
|
||||
&scale_float );
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_float = *( ( float* )( post_op->sum )->scale_factor + j_scale );
|
||||
}
|
||||
return scale_float;
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ZP_type,C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
|
||||
(\
|
||||
ACCUM_type temp_accum,\
|
||||
aocl_post_op* post_op, \
|
||||
dim_t j \
|
||||
dim_t j, \
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type \
|
||||
)\
|
||||
{ \
|
||||
dim_t j_scale = j; \
|
||||
@@ -819,10 +855,11 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
|
||||
j_zp = 0; \
|
||||
} \
|
||||
\
|
||||
float temp_sf = convert_scale_store_type_to_float(post_op, sf_stor_type, j_scale); \
|
||||
ACCUM_type out_temp_accum = \
|
||||
( ACCUM_type )min( \
|
||||
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
|
||||
( *( ( SCALE_type* )( post_op->sum )->scale_factor + j_scale ) ) ) + \
|
||||
( temp_sf ) ) + \
|
||||
*( ( ZP_type* )( post_op->sum )->zero_point + j_zp ), \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
@@ -834,9 +871,11 @@ static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
dim_t j,
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
|
||||
)
|
||||
{
|
||||
( void ) sf_stor_type;
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
@@ -861,9 +900,11 @@ static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
|
||||
(
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j
|
||||
dim_t j,
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
|
||||
)
|
||||
{
|
||||
( void ) sf_stor_type;
|
||||
dim_t j_scale = j;
|
||||
if ( ( post_op->sum )->scale_factor_len == 1 )
|
||||
{
|
||||
@@ -1415,6 +1456,8 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
char * bias_stor_type = ""; \
|
||||
bool is_zp_stor_type = FALSE; \
|
||||
char* zp_stor_type = ""; \
|
||||
bool is_sf_stor_type = FALSE; \
|
||||
char* sf_stor_type = ""; \
|
||||
bool is_matadd_stor_type = FALSE; \
|
||||
char* matadd_stor_type = ""; \
|
||||
bool is_matmul_stor_type = FALSE; \
|
||||
@@ -1537,6 +1580,39 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
is_scalar_scale = TRUE; \
|
||||
} \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "sf_stor_type" ) == 0) \
|
||||
{ \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = FALSE; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "f32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = TRUE; \
|
||||
sf_stor_type = "F32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "bf16" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = TRUE; \
|
||||
sf_stor_type = "BF16"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = TRUE; \
|
||||
sf_stor_type = "S32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = TRUE; \
|
||||
sf_stor_type = "S8"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "u8" ) == 0 ) ) \
|
||||
{ \
|
||||
is_sf_stor_type = TRUE; \
|
||||
sf_stor_type = "U8"; \
|
||||
} \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "zp" ) == 0 ) \
|
||||
{ \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
@@ -2007,16 +2083,78 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
} \
|
||||
\
|
||||
/* Allocate scale buffer, return early if alloc fails.*/ \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
/* Fill scale factor */ \
|
||||
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
|
||||
GEN_FUNC_NAME(fill_array_,DSCALE_type)(temp_dscale_ptr, n_scale); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
|
||||
if(is_sf_stor_type == TRUE) \
|
||||
{ \
|
||||
if( ( strcmp( sf_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->sum )->sf_stor_type = AOCL_GEMM_BF16; \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( bfloat16 ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,bfloat16)( ( post_ops->sum )->scale_factor, n_scale ); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
} \
|
||||
else if( ( strcmp( sf_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->sum )->sf_stor_type = AOCL_GEMM_F32; \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( float ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,float)( ( post_ops->sum )->scale_factor, n_scale ); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
} \
|
||||
else if( ( strcmp( sf_stor_type, "S32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->sum )->sf_stor_type = AOCL_GEMM_INT32; \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( int32_t ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int32_t)( ( post_ops->sum )->scale_factor, n_scale ); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
} \
|
||||
else if( ( strcmp( sf_stor_type, "S8" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->sum )->sf_stor_type = AOCL_GEMM_INT8; \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( int8_t ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,int8_t)( ( post_ops->sum )->scale_factor, n_scale ); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
} \
|
||||
else if( ( strcmp( sf_stor_type, "U8" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->sum )->sf_stor_type = AOCL_GEMM_UINT8; \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( uint8_t ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_,uint8_t)( ( post_ops->sum )->scale_factor, n_scale ); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->sum )->scale_factor = malloc( n_scale * sizeof( DSCALE_type ) ); \
|
||||
if ( ( post_ops->sum )->scale_factor == NULL ) \
|
||||
{ \
|
||||
goto err_handler; \
|
||||
} \
|
||||
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
|
||||
GEN_FUNC_NAME(fill_array_,DSCALE_type)(temp_dscale_ptr, n_scale); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
|
||||
} \
|
||||
\
|
||||
if(is_zp_stor_type == TRUE) \
|
||||
{ \
|
||||
@@ -2086,7 +2224,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
goto err_handler; \
|
||||
} \
|
||||
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)(temp_dzero_point_ptr, n_zp); \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)(temp_dzero_point_ptr, n_zp); \
|
||||
( post_ops->sum )->zero_point_len = n_zp; \
|
||||
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_zp;i++) temp_dzero_point_ptr[i] = abs(temp_dzero_point_ptr[i]);\
|
||||
} \
|
||||
|
||||
@@ -404,7 +404,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
else if ( post_op->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_SFX) \
|
||||
(temp_accum, post_op, j); \
|
||||
(temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
|
||||
Reference in New Issue
Block a user