Added multiple ZP type checks in INT8 APIs

- Currently the int8/uint8 APIs do not support multiple ZP types;
   they work only with the int8 type or the uint8 type.
 - Support for multiple ZP types is added to these kernels, along
   with additional macros to support the operations.
 - Modified the bench downscale reference code to support the updated
   types.

AMD-Internal : [ SWLCSG-3304 ]

Change-Id: Ia5e40ee3705a38d09262086d20731e8f0a126987
This commit is contained in:
varshav2
2025-04-07 00:20:07 +05:30
committed by Nallani Bhaskar
parent 267aae80ea
commit b9998a1d7f
19 changed files with 4848 additions and 2083 deletions

View File

@@ -616,7 +616,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
else if ( post_op[bs_i]->seq_vector[op_id] == SCALE ) \
{ \
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(post_temp_accum, post_op[bs_i], j, ( post_op[bs_i]->sum )->sf_stor_type); \
(post_temp_accum, post_op[bs_i], j, ( post_op[bs_i]->sum )->sf_stor_type, \
( post_op[bs_i]->sum )->zp_stor_type); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_ADD ) \
{ \

View File

@@ -598,7 +598,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
else if ( post_op->seq_vector[op_id] == SCALE ) \
{ \
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(post_temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
(post_temp_accum, post_op, j, ( post_op->sum )->sf_stor_type, \
( post_op->sum )->zp_stor_type); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
{ \

View File

@@ -209,42 +209,6 @@ static inline float eltwise_ops_accuracy_check_downscale_bf16obf16
return out_temp_accum;
}
/*
 * Fetch one zero-point element from ( post_op->sum )->zero_point and
 * return it as a float.
 *
 *   post_op      - post-op descriptor whose sum->zero_point buffer is read.
 *   zp_stor_type - selects the element type used to interpret the buffer.
 *   j_zp         - element index into the zero-point buffer.
 *
 * BF16, INT32, INT8 and UINT8 elements go through their matching
 * *_to_float helper; every other storage type is read directly as float.
 */
static inline float convert_zp_store_type_to_float
     (
       aocl_post_op* post_op,
       AOCL_PARAMS_STORAGE_TYPES zp_stor_type,
       dim_t j_zp
     )
{
    float converted = 0.0;

    if ( zp_stor_type == AOCL_GEMM_BF16 )
    {
        bfloat16_to_float( ( ( bfloat16* )( post_op->sum )->zero_point )[j_zp],
                           &converted );
        return converted;
    }

    if ( zp_stor_type == AOCL_GEMM_INT32 )
    {
        int32_t_to_float( ( ( int32_t* )( post_op->sum )->zero_point )[j_zp],
                          &converted );
        return converted;
    }

    if ( zp_stor_type == AOCL_GEMM_INT8 )
    {
        int8_t_to_float( ( ( int8_t* )( post_op->sum )->zero_point )[j_zp],
                         &converted );
        return converted;
    }

    if ( zp_stor_type == AOCL_GEMM_UINT8 )
    {
        uint8_t_to_float( ( ( uint8_t* )( post_op->sum )->zero_point )[j_zp],
                          &converted );
        return converted;
    }

    /* Fallback: the buffer is treated as an array of float. */
    converted = ( ( float* )( post_op->sum )->zero_point )[j_zp];
    return converted;
}
static inline float eltwise_ops_accuracy_check_downscale_f32of32
(
float temp_accum,

View File

@@ -834,13 +834,49 @@ float convert_scale_store_type_to_float
return scale_float;
}
/*
 * Read the zero-point element at index j_zp from
 * ( post_op->sum )->zero_point and return its value as a float.
 *
 *   post_op      - post-op descriptor whose sum->zero_point buffer is read.
 *   zp_stor_type - storage type used to interpret the zero-point buffer.
 *   j_zp         - element index into the zero-point buffer.
 *
 * BF16, INT32, INT8 and UINT8 elements are converted with the matching
 * *_to_float helper; any other storage type is read directly as float.
 */
float convert_zp_store_type_to_float
     (
       aocl_post_op* post_op,
       AOCL_PARAMS_STORAGE_TYPES zp_stor_type,
       dim_t j_zp
     )
{
    float zp_value = 0.0;

    if ( zp_stor_type == AOCL_GEMM_BF16 )
    {
        bfloat16_to_float( ( ( bfloat16* )( post_op->sum )->zero_point )[j_zp],
                           &zp_value );
        return zp_value;
    }

    if ( zp_stor_type == AOCL_GEMM_INT32 )
    {
        int32_t_to_float( ( ( int32_t* )( post_op->sum )->zero_point )[j_zp],
                          &zp_value );
        return zp_value;
    }

    if ( zp_stor_type == AOCL_GEMM_INT8 )
    {
        int8_t_to_float( ( ( int8_t* )( post_op->sum )->zero_point )[j_zp],
                         &zp_value );
        return zp_value;
    }

    if ( zp_stor_type == AOCL_GEMM_UINT8 )
    {
        uint8_t_to_float( ( ( uint8_t* )( post_op->sum )->zero_point )[j_zp],
                          &zp_value );
        return zp_value;
    }

    /* No conversion helper applies: interpret the buffer as float. */
    zp_value = ( ( float* )( post_op->sum )->zero_point )[j_zp];
    return zp_value;
}
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ZP_type,C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
(\
ACCUM_type temp_accum,\
aocl_post_op* post_op, \
dim_t j, \
AOCL_PARAMS_STORAGE_TYPES sf_stor_type \
AOCL_PARAMS_STORAGE_TYPES sf_stor_type, \
AOCL_PARAMS_STORAGE_TYPES zp_stor_type \
)\
{ \
dim_t j_scale = j; \
@@ -855,12 +891,20 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
j_zp = 0; \
} \
\
float temp_zp; \
float temp_sf = convert_scale_store_type_to_float(post_op, sf_stor_type, j_scale); \
if( zp_stor_type != NULLTYPE ) \
{ \
temp_zp = convert_zp_store_type_to_float(post_op, zp_stor_type, j_zp); \
} \
else \
{ \
temp_zp = *( ( ZP_type* )( post_op->sum )->zero_point + j_zp ); \
} \
ACCUM_type out_temp_accum = \
( ACCUM_type )min( \
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
( temp_sf ) ) + \
*( ( ZP_type* )( post_op->sum )->zero_point + j_zp ), \
( temp_sf ) ) + temp_zp, \
DSCALE_CLIP_MIN ), \
DSCALE_CLIP_MAX ); \
return out_temp_accum; \
@@ -872,7 +916,8 @@ static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
float temp_accum,
aocl_post_op* post_op,
dim_t j,
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
AOCL_PARAMS_STORAGE_TYPES sf_stor_type,
AOCL_PARAMS_STORAGE_TYPES zp_stor_type
)
{
( void ) sf_stor_type;
@@ -901,7 +946,8 @@ static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
float temp_accum,
aocl_post_op* post_op,
dim_t j,
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
AOCL_PARAMS_STORAGE_TYPES sf_stor_type,
AOCL_PARAMS_STORAGE_TYPES zp_stor_type
)
{
( void ) sf_stor_type;
@@ -1625,7 +1671,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
} \
else if ( strcmp( ops_tok, "zp_stor_type" ) == 0) \
{ \
ops_tok = strtok( NULL, ", " ); \
ops_tok = strtok( NULL, ", " ); \
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
{ \
is_zp_stor_type = FALSE; \

View File

@@ -404,7 +404,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
else if ( post_op->seq_vector[op_id] == SCALE ) \
{ \
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_SFX) \
(temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
(temp_accum, post_op, j, ( post_op->sum )->sf_stor_type, \
( post_op->sum )->zp_stor_type); \
} \
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
{ \