mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Added destination scale type check in INT8 APIs
- Updated the S8 main, GEMV, m_, n_ and mn_ fringe kernels to support multiple scale types for vector and scalar scales.
- Updated the U8 main, GEMV, m_, n_, extMR_ and mn_ fringe kernels to support multiple scale types for vector and scalar scales.
- Updated the bench to accommodate multiple scale type inputs, and modified downscale_accuracy_check_ to verify against multiple scale type inputs.

AMD Internal: [ SWLCSG-3304 ]
Change-Id: I7b9f3ec8ea830d3265f72d18a0aa36086e14a86e
This commit is contained in:
@@ -87,6 +87,7 @@ typedef struct
|
||||
dim_t scale_factor_len;
|
||||
dim_t zero_point_len;
|
||||
AOCL_PARAMS_STORAGE_TYPES zp_stor_type;
|
||||
AOCL_PARAMS_STORAGE_TYPES sf_stor_type;
|
||||
} aocl_post_op_sum; // Also use for scale.
|
||||
|
||||
typedef struct
|
||||
|
||||
@@ -299,7 +299,8 @@ err_t lpgemm_translate_to_pre_ops_list
|
||||
return BLIS_SUCCESS;
|
||||
}
|
||||
|
||||
BLIS_INLINE void lpgemm_set_node_params(
|
||||
BLIS_INLINE void lpgemm_set_node_params
|
||||
(
|
||||
lpgemm_post_op *post_op_node,
|
||||
LPGEMM_POST_OP_CODE op_code,
|
||||
void *op1,
|
||||
@@ -309,7 +310,9 @@ BLIS_INLINE void lpgemm_set_node_params(
|
||||
dim_t scale_factor_len,
|
||||
bool is_power_of_2,
|
||||
AOCL_STORAGE_TYPE stor_type,
|
||||
AOCL_STORAGE_TYPE zp_stor_type)
|
||||
AOCL_STORAGE_TYPE zp_stor_type,
|
||||
AOCL_STORAGE_TYPE sf_stor_type
|
||||
)
|
||||
{
|
||||
post_op_node->op_code = op_code;
|
||||
post_op_node->op_args1 = op1;
|
||||
@@ -320,6 +323,7 @@ BLIS_INLINE void lpgemm_set_node_params(
|
||||
post_op_node->is_power_of_2 = is_power_of_2;
|
||||
post_op_node->stor_type = stor_type;
|
||||
post_op_node->zp_stor_type = zp_stor_type;
|
||||
post_op_node->sf_stor_type = sf_stor_type;
|
||||
post_op_node->next = NULL;
|
||||
}
|
||||
|
||||
@@ -341,7 +345,8 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
lpgemm_set_node_params
|
||||
(
|
||||
post_op_list, POST_OPS_DISABLE,
|
||||
NULL, NULL, NULL, NULL, 0, FALSE, NONE, NONE
|
||||
NULL, NULL, NULL, NULL, 0, FALSE, NONE,
|
||||
NONE, NONE
|
||||
);
|
||||
|
||||
return BLIS_SUCCESS;
|
||||
@@ -352,7 +357,8 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
lpgemm_set_node_params
|
||||
(
|
||||
post_op_list, POST_OPS_DISABLE,
|
||||
NULL, NULL, NULL, NULL, 0, FALSE, NONE, NONE
|
||||
NULL, NULL, NULL, NULL, 0, FALSE, NONE,
|
||||
NONE, NONE
|
||||
);
|
||||
|
||||
bli_print_msg(" Max supported post-ops is 5, supplied input post-ops" \
|
||||
@@ -381,7 +387,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
( post_op_unparsed->sum + s_i )->scale_factor,
|
||||
( post_op_unparsed->sum + s_i )->scale_factor_len,
|
||||
( post_op_unparsed->sum + s_i )->is_power_of_2,
|
||||
NONE, NONE
|
||||
NONE, NONE, NONE
|
||||
);
|
||||
|
||||
s_i += 1;
|
||||
@@ -445,7 +451,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
( post_op_unparsed->eltwise + e_i )->scale_factor,
|
||||
( post_op_unparsed->eltwise + e_i )->scale_factor_len,
|
||||
( post_op_unparsed->eltwise + e_i )->is_power_of_2,
|
||||
NONE, NONE
|
||||
NONE, NONE, NONE
|
||||
);
|
||||
e_i += 1;
|
||||
}
|
||||
@@ -463,7 +469,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
(
|
||||
( post_op_list + i ), POST_OPS_BIAS,
|
||||
( post_op_unparsed->bias + b_i )->bias,
|
||||
meta_arg, NULL, NULL, 0, FALSE, tmp_stor_type, NONE
|
||||
meta_arg, NULL, NULL, 0, FALSE, tmp_stor_type, NONE, NONE
|
||||
);
|
||||
|
||||
b_i += 1;
|
||||
@@ -502,6 +508,8 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
|
||||
AOCL_STORAGE_TYPE tmp_zp_stor_type =
|
||||
get_stor_type( ( post_op_unparsed->sum + s_i )->zp_stor_type );
|
||||
AOCL_STORAGE_TYPE tmp_sf_stor_type =
|
||||
get_stor_type( ( post_op_unparsed->sum + s_i )->sf_stor_type );
|
||||
|
||||
lpgemm_set_node_params
|
||||
(
|
||||
@@ -510,7 +518,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
meta_arg, &( ( post_op_unparsed->sum + s_i )->zero_point_len ),
|
||||
( post_op_unparsed->sum + s_i )->scale_factor,
|
||||
( post_op_unparsed->sum + s_i )->scale_factor_len,
|
||||
FALSE, NONE, tmp_zp_stor_type
|
||||
FALSE, NONE, tmp_zp_stor_type, tmp_sf_stor_type
|
||||
);
|
||||
|
||||
s_i += 1;
|
||||
@@ -535,7 +543,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
meta_arg, &( ( post_op_unparsed->matrix_add + m_i )->ldm ),
|
||||
( post_op_unparsed->matrix_add + m_i )->scale_factor,
|
||||
( post_op_unparsed->matrix_add + m_i )->scale_factor_len,
|
||||
FALSE, tmp_stor_type, NONE
|
||||
FALSE, tmp_stor_type, NONE, NONE
|
||||
);
|
||||
|
||||
m_i += 1;
|
||||
@@ -560,7 +568,7 @@ err_t lpgemm_translate_to_post_ops_list
|
||||
meta_arg, &( ( post_op_unparsed->matrix_mul + mul_i )->ldm ),
|
||||
( post_op_unparsed->matrix_mul + mul_i )->scale_factor,
|
||||
( post_op_unparsed->matrix_mul + mul_i )->scale_factor_len,
|
||||
FALSE, tmp_stor_type, NONE
|
||||
FALSE, tmp_stor_type, NONE, NONE
|
||||
);
|
||||
|
||||
mul_i += 1;
|
||||
|
||||
@@ -68,6 +68,7 @@ typedef struct lpgemm_post_op_t
|
||||
bool is_power_of_2;
|
||||
uint64_t stor_type;
|
||||
uint64_t zp_stor_type;
|
||||
uint64_t sf_stor_type; //Introduced for sf store type
|
||||
struct lpgemm_post_op_t* next;
|
||||
} lpgemm_post_op;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user