Added destination scale type check in INT8 API's

- Updated the S8 main, GEMV, m_, n_ and mn_ fringe kernels to support
   multiple scale types for vector and scalar scales

 - Updated the U8 main, GEMV, m_, n_, extMR_ and mn_ fringe kernels to
   support multiple scale types for vector and scalar scales

 - Updated the bench to accommodate multiple scale type input, and
   modified the downscale_accuracy_check_ to verify with multiple scale
   type inputs.

AMD Internal: [ SWLCSG-3304 ]

Change-Id: I7b9f3ec8ea830d3265f72d18a0aa36086e14a86e
This commit is contained in:
varshav
2025-03-24 06:30:07 +00:00
committed by Nallani Bhaskar
parent 350c7186e5
commit 81d219e3f8
21 changed files with 4469 additions and 1094 deletions

View File

@@ -87,6 +87,7 @@ typedef struct
dim_t scale_factor_len;
dim_t zero_point_len;
AOCL_PARAMS_STORAGE_TYPES zp_stor_type;
AOCL_PARAMS_STORAGE_TYPES sf_stor_type;
} aocl_post_op_sum; // Also use for scale.
typedef struct

View File

@@ -299,7 +299,8 @@ err_t lpgemm_translate_to_pre_ops_list
return BLIS_SUCCESS;
}
BLIS_INLINE void lpgemm_set_node_params(
BLIS_INLINE void lpgemm_set_node_params
(
lpgemm_post_op *post_op_node,
LPGEMM_POST_OP_CODE op_code,
void *op1,
@@ -309,7 +310,9 @@ BLIS_INLINE void lpgemm_set_node_params(
dim_t scale_factor_len,
bool is_power_of_2,
AOCL_STORAGE_TYPE stor_type,
AOCL_STORAGE_TYPE zp_stor_type)
AOCL_STORAGE_TYPE zp_stor_type,
AOCL_STORAGE_TYPE sf_stor_type
)
{
post_op_node->op_code = op_code;
post_op_node->op_args1 = op1;
@@ -320,6 +323,7 @@ BLIS_INLINE void lpgemm_set_node_params(
post_op_node->is_power_of_2 = is_power_of_2;
post_op_node->stor_type = stor_type;
post_op_node->zp_stor_type = zp_stor_type;
post_op_node->sf_stor_type = sf_stor_type;
post_op_node->next = NULL;
}
@@ -341,7 +345,8 @@ err_t lpgemm_translate_to_post_ops_list
lpgemm_set_node_params
(
post_op_list, POST_OPS_DISABLE,
NULL, NULL, NULL, NULL, 0, FALSE, NONE, NONE
NULL, NULL, NULL, NULL, 0, FALSE, NONE,
NONE, NONE
);
return BLIS_SUCCESS;
@@ -352,7 +357,8 @@ err_t lpgemm_translate_to_post_ops_list
lpgemm_set_node_params
(
post_op_list, POST_OPS_DISABLE,
NULL, NULL, NULL, NULL, 0, FALSE, NONE, NONE
NULL, NULL, NULL, NULL, 0, FALSE, NONE,
NONE, NONE
);
bli_print_msg(" Max supported post-ops is 5, supplied input post-ops" \
@@ -381,7 +387,7 @@ err_t lpgemm_translate_to_post_ops_list
( post_op_unparsed->sum + s_i )->scale_factor,
( post_op_unparsed->sum + s_i )->scale_factor_len,
( post_op_unparsed->sum + s_i )->is_power_of_2,
NONE, NONE
NONE, NONE, NONE
);
s_i += 1;
@@ -445,7 +451,7 @@ err_t lpgemm_translate_to_post_ops_list
( post_op_unparsed->eltwise + e_i )->scale_factor,
( post_op_unparsed->eltwise + e_i )->scale_factor_len,
( post_op_unparsed->eltwise + e_i )->is_power_of_2,
NONE, NONE
NONE, NONE, NONE
);
e_i += 1;
}
@@ -463,7 +469,7 @@ err_t lpgemm_translate_to_post_ops_list
(
( post_op_list + i ), POST_OPS_BIAS,
( post_op_unparsed->bias + b_i )->bias,
meta_arg, NULL, NULL, 0, FALSE, tmp_stor_type, NONE
meta_arg, NULL, NULL, 0, FALSE, tmp_stor_type, NONE, NONE
);
b_i += 1;
@@ -502,6 +508,8 @@ err_t lpgemm_translate_to_post_ops_list
AOCL_STORAGE_TYPE tmp_zp_stor_type =
get_stor_type( ( post_op_unparsed->sum + s_i )->zp_stor_type );
AOCL_STORAGE_TYPE tmp_sf_stor_type =
get_stor_type( ( post_op_unparsed->sum + s_i )->sf_stor_type );
lpgemm_set_node_params
(
@@ -510,7 +518,7 @@ err_t lpgemm_translate_to_post_ops_list
meta_arg, &( ( post_op_unparsed->sum + s_i )->zero_point_len ),
( post_op_unparsed->sum + s_i )->scale_factor,
( post_op_unparsed->sum + s_i )->scale_factor_len,
FALSE, NONE, tmp_zp_stor_type
FALSE, NONE, tmp_zp_stor_type, tmp_sf_stor_type
);
s_i += 1;
@@ -535,7 +543,7 @@ err_t lpgemm_translate_to_post_ops_list
meta_arg, &( ( post_op_unparsed->matrix_add + m_i )->ldm ),
( post_op_unparsed->matrix_add + m_i )->scale_factor,
( post_op_unparsed->matrix_add + m_i )->scale_factor_len,
FALSE, tmp_stor_type, NONE
FALSE, tmp_stor_type, NONE, NONE
);
m_i += 1;
@@ -560,7 +568,7 @@ err_t lpgemm_translate_to_post_ops_list
meta_arg, &( ( post_op_unparsed->matrix_mul + mul_i )->ldm ),
( post_op_unparsed->matrix_mul + mul_i )->scale_factor,
( post_op_unparsed->matrix_mul + mul_i )->scale_factor_len,
FALSE, tmp_stor_type, NONE
FALSE, tmp_stor_type, NONE, NONE
);
mul_i += 1;

View File

@@ -68,6 +68,7 @@ typedef struct lpgemm_post_op_t
bool is_power_of_2;
uint64_t stor_type;
uint64_t zp_stor_type;
uint64_t sf_stor_type; //Introduced for sf store type
struct lpgemm_post_op_t* next;
} lpgemm_post_op;