mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Added mutiple ZP type checks in INT8 APIs
- Currently the int8/uint8 APIs do not support multiple ZP types, but works only with int8 type or uint8 type. - The support is added to enable multiple zp types in these kernels and added additional macros to support the operations. - Modified the bench downscale reference code to support the updated types. AMD-Internal : [ SWLCSG-3304 ] Change-Id: Ia5e40ee3705a38d09262086d20731e8f0a126987
This commit is contained in:
committed by
Nallani Bhaskar
parent
267aae80ea
commit
b9998a1d7f
@@ -616,7 +616,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
else if ( post_op[bs_i]->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
|
||||
(post_temp_accum, post_op[bs_i], j, ( post_op[bs_i]->sum )->sf_stor_type); \
|
||||
(post_temp_accum, post_op[bs_i], j, ( post_op[bs_i]->sum )->sf_stor_type, \
|
||||
( post_op[bs_i]->sum )->zp_stor_type); \
|
||||
} \
|
||||
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
|
||||
@@ -598,7 +598,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
else if ( post_op->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
post_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
|
||||
(post_temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
|
||||
(post_temp_accum, post_op, j, ( post_op->sum )->sf_stor_type, \
|
||||
( post_op->sum )->zp_stor_type); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
|
||||
@@ -209,42 +209,6 @@ static inline float eltwise_ops_accuracy_check_downscale_bf16obf16
|
||||
return out_temp_accum;
|
||||
}
|
||||
|
||||
static inline float convert_zp_store_type_to_float
|
||||
(
|
||||
aocl_post_op* post_op,
|
||||
AOCL_PARAMS_STORAGE_TYPES zp_stor_type,
|
||||
dim_t j_zp
|
||||
)
|
||||
{
|
||||
float zp_float = 0.0;
|
||||
if(zp_stor_type == AOCL_GEMM_BF16)
|
||||
{
|
||||
bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else if(zp_stor_type == AOCL_GEMM_INT32)
|
||||
{
|
||||
int32_t_to_float( *( ( int32_t* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else if(zp_stor_type == AOCL_GEMM_INT8)
|
||||
{
|
||||
int8_t_to_float( *( ( int8_t* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else if(zp_stor_type == AOCL_GEMM_UINT8)
|
||||
{
|
||||
uint8_t_to_float( *( ( uint8_t* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else
|
||||
{
|
||||
zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
|
||||
}
|
||||
return zp_float;
|
||||
}
|
||||
|
||||
|
||||
static inline float eltwise_ops_accuracy_check_downscale_f32of32
|
||||
(
|
||||
float temp_accum,
|
||||
|
||||
@@ -834,13 +834,49 @@ float convert_scale_store_type_to_float
|
||||
return scale_float;
|
||||
}
|
||||
|
||||
float convert_zp_store_type_to_float
|
||||
(
|
||||
aocl_post_op* post_op,
|
||||
AOCL_PARAMS_STORAGE_TYPES zp_stor_type,
|
||||
dim_t j_zp
|
||||
)
|
||||
{
|
||||
float zp_float = 0.0;
|
||||
if(zp_stor_type == AOCL_GEMM_BF16)
|
||||
{
|
||||
bfloat16_to_float( *( ( bfloat16* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else if(zp_stor_type == AOCL_GEMM_INT32)
|
||||
{
|
||||
int32_t_to_float( *( ( int32_t* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else if(zp_stor_type == AOCL_GEMM_INT8 )
|
||||
{
|
||||
int8_t_to_float( *( ( int8_t* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else if(zp_stor_type == AOCL_GEMM_UINT8)
|
||||
{
|
||||
uint8_t_to_float( *( ( uint8_t* )( post_op->sum )->zero_point + j_zp ),
|
||||
&zp_float );
|
||||
}
|
||||
else
|
||||
{
|
||||
zp_float = *( ( float* )( post_op->sum )->zero_point + j_zp );
|
||||
}
|
||||
return zp_float;
|
||||
}
|
||||
|
||||
#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ZP_type,C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \
|
||||
static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \
|
||||
(\
|
||||
ACCUM_type temp_accum,\
|
||||
aocl_post_op* post_op, \
|
||||
dim_t j, \
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type \
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type, \
|
||||
AOCL_PARAMS_STORAGE_TYPES zp_stor_type \
|
||||
)\
|
||||
{ \
|
||||
dim_t j_scale = j; \
|
||||
@@ -855,12 +891,20 @@ static inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX
|
||||
j_zp = 0; \
|
||||
} \
|
||||
\
|
||||
float temp_zp; \
|
||||
float temp_sf = convert_scale_store_type_to_float(post_op, sf_stor_type, j_scale); \
|
||||
if( zp_stor_type != NULLTYPE ) \
|
||||
{ \
|
||||
temp_zp = convert_zp_store_type_to_float(post_op, zp_stor_type, j_zp); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
temp_zp = *( ( ZP_type* )( post_op->sum )->zero_point + j_zp ); \
|
||||
} \
|
||||
ACCUM_type out_temp_accum = \
|
||||
( ACCUM_type )min( \
|
||||
max( nearbyintf( ( SCALE_type )( temp_accum ) * \
|
||||
( temp_sf ) ) + \
|
||||
*( ( ZP_type* )( post_op->sum )->zero_point + j_zp ), \
|
||||
( temp_sf ) ) + temp_zp, \
|
||||
DSCALE_CLIP_MIN ), \
|
||||
DSCALE_CLIP_MAX ); \
|
||||
return out_temp_accum; \
|
||||
@@ -872,7 +916,8 @@ static inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j,
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type,
|
||||
AOCL_PARAMS_STORAGE_TYPES zp_stor_type
|
||||
)
|
||||
{
|
||||
( void ) sf_stor_type;
|
||||
@@ -901,7 +946,8 @@ static inline float mat_mul_accuracy_check_downscale_f32f32f32of32
|
||||
float temp_accum,
|
||||
aocl_post_op* post_op,
|
||||
dim_t j,
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type,
|
||||
AOCL_PARAMS_STORAGE_TYPES zp_stor_type
|
||||
)
|
||||
{
|
||||
( void ) sf_stor_type;
|
||||
@@ -1625,7 +1671,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
} \
|
||||
else if ( strcmp( ops_tok, "zp_stor_type" ) == 0) \
|
||||
{ \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
ops_tok = strtok( NULL, ", " ); \
|
||||
if( ( strcmp( ops_tok, "na" ) == 0 ) ) \
|
||||
{ \
|
||||
is_zp_stor_type = FALSE; \
|
||||
|
||||
@@ -404,7 +404,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
else if ( post_op->seq_vector[op_id] == SCALE ) \
|
||||
{ \
|
||||
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_SFX) \
|
||||
(temp_accum, post_op, j, ( post_op->sum )->sf_stor_type); \
|
||||
(temp_accum, post_op, j, ( post_op->sum )->sf_stor_type, \
|
||||
( post_op->sum )->zp_stor_type); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
|
||||
Reference in New Issue
Block a user