mirror of
https://github.com/amd/blis.git
synced 2026-04-21 08:08:51 +00:00
Multi-data type buffer and scale support for matrix add|mul post-ops in s32 API.
-As it stands the buffer type in matrix add|mul post-ops is expected to be the same as that of the output C matrix type. This limitation is now removed and user can specify the buffer type by setting the stor_type attribute in add|mul post-op struct. As of now int8, int32, bfloat16 and float types are supported for the buffer in s32 micro-kernels. The same support is also added for bf16 micro-kernels, with bfloat16 and float supported for now. -Additionally the values (from buffer) are added/multiplied as is to the output registers while performing the matrix add|mul post-ops. Support is added for scaling these values before using them in the post-ops. Both scalar and vector scale_factors are supported. -The bias_stor_type attribute is renamed to stor_type in bias post-ops. AMD-Internal: [SWLCSG-3319] Change-Id: I4046ab84481b02c55a71ebb7038e38aec840c0fa
This commit is contained in:
committed by
MithunMohan KadavilMadanaMohanan
parent
051c9ac7a2
commit
ef4286a97e
@@ -872,7 +872,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
if ( post_op[bs_i]->seq_vector[op_id] == BIAS ) \
|
||||
{ \
|
||||
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
|
||||
( ( post_op[bs_i]->bias )->bias, j, ( post_op[bs_i]->bias )->bias_stor_type ); \
|
||||
( ( post_op[bs_i]->bias )->bias, j, ( post_op[bs_i]->bias )->stor_type ); \
|
||||
} \
|
||||
else if ( post_op[bs_i]->seq_vector[op_id] == ELTWISE ) \
|
||||
{ \
|
||||
|
||||
@@ -893,7 +893,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
if ( post_op->seq_vector[op_id] == BIAS ) \
|
||||
{ \
|
||||
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
|
||||
( ( post_op->bias )->bias, j, ( post_op->bias )->bias_stor_type ); \
|
||||
( ( post_op->bias )->bias, j, ( post_op->bias )->stor_type ); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == ELTWISE ) \
|
||||
{ \
|
||||
|
||||
@@ -286,7 +286,7 @@ void eltwise_ops_accuracy_check_driver_ ## LP_SFX \
|
||||
if ( post_op->seq_vector[op_id] == BIAS ) \
|
||||
{ \
|
||||
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,LP_SFX) \
|
||||
( ( post_op->bias )->bias, j, ( post_op->bias )->bias_stor_type ); \
|
||||
( ( post_op->bias )->bias, j, ( post_op->bias )->stor_type ); \
|
||||
} \
|
||||
else if ( post_op->seq_vector[op_id] == ELTWISE ) \
|
||||
{ \
|
||||
@@ -743,17 +743,17 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
{ \
|
||||
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )-> bias_stor_type = AOCL_GEMM_BF16; \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )-> bias_stor_type = AOCL_GEMM_F32; \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->bias )-> bias_stor_type = NULLTYPE ; \
|
||||
( post_ops->bias )->stor_type = NULLTYPE ; \
|
||||
} \
|
||||
if ( global_dscale_out == 'y' ) \
|
||||
{ \
|
||||
|
||||
@@ -228,11 +228,11 @@ static inline float get_bias_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
void* post_op_bias_ptr, \
|
||||
dim_t j, \
|
||||
AOCL_PARAMS_STORAGE_TYPES bais_stor_type \
|
||||
AOCL_PARAMS_STORAGE_TYPES bias_stor_type \
|
||||
) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
if(bais_stor_type == AOCL_GEMM_F32) \
|
||||
if(bias_stor_type == AOCL_GEMM_F32) \
|
||||
{ \
|
||||
return *( ( float* )post_op_bias_ptr + j ); \
|
||||
} \
|
||||
@@ -245,10 +245,10 @@ static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
void* post_op_bias_ptr, \
|
||||
dim_t j, \
|
||||
AOCL_PARAMS_STORAGE_TYPES bais_stor_type \
|
||||
AOCL_PARAMS_STORAGE_TYPES bias_stor_type \
|
||||
) \
|
||||
{ \
|
||||
if(bais_stor_type == AOCL_GEMM_BF16) \
|
||||
if(bias_stor_type == AOCL_GEMM_BF16) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
@@ -425,9 +425,9 @@ static inline float get_matrix_mul_post_op_val_ ## BLAS_SFX \
|
||||
) \
|
||||
{ \
|
||||
return GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( \
|
||||
val, j, scl_fctr, scl_fctr_len \
|
||||
); \
|
||||
( \
|
||||
val, j, scl_fctr, scl_fctr_len \
|
||||
); \
|
||||
} \
|
||||
|
||||
#define GEN_GET_MATRIX_MUL_POST_OP_VAL(C_type,ACCUM_type,BLAS_SFX) \
|
||||
@@ -440,9 +440,9 @@ static inline ACCUM_type get_matrix_mul_post_op_val_ ## BLAS_SFX \
|
||||
) \
|
||||
{ \
|
||||
return GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
|
||||
( \
|
||||
val, j, scl_fctr, scl_fctr_len \
|
||||
); \
|
||||
( \
|
||||
val, j, scl_fctr, scl_fctr_len \
|
||||
); \
|
||||
} \
|
||||
|
||||
/* Final output type value getter. */
|
||||
@@ -912,21 +912,24 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
{ \
|
||||
if( ( strcmp( bias_stor_type, "BF16" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )-> bias_stor_type = AOCL_GEMM_BF16; \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,bfloat16)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "F32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )-> bias_stor_type = AOCL_GEMM_F32; \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,float)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
( post_ops->bias )-> bias_stor_type = NULLTYPE; \
|
||||
( post_ops->bias )->stor_type = NULLTYPE; \
|
||||
if( global_dscale_out == 'y') \
|
||||
{ \
|
||||
if ( strcmp(#BIAS_type, "bfloat16") == 0 ) { \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_BF16; \
|
||||
} \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,BIAS_type)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else \
|
||||
@@ -1144,6 +1147,9 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
temp_dscale_ptr[i] = ( ( DSCALE_type )2 ); \
|
||||
} \
|
||||
( post_ops->matrix_add )->scale_factor_len = n_scale; \
|
||||
/* Set buffer type same as c_store type for now.
|
||||
* TODO: Update to cover more data types. */ \
|
||||
( post_ops->matrix_add )->stor_type = NULLTYPE; \
|
||||
} \
|
||||
\
|
||||
if ( is_matrix_mul == TRUE ) \
|
||||
@@ -1194,6 +1200,9 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
temp_dscale_ptr[i] = ( ( DSCALE_type )2 ); \
|
||||
} \
|
||||
( post_ops->matrix_mul )->scale_factor_len = n_scale; \
|
||||
/* Set buffer type same as c_store type for now.
|
||||
* TODO: Update to cover more data types. */ \
|
||||
( post_ops->matrix_mul )->stor_type = NULLTYPE; \
|
||||
} \
|
||||
\
|
||||
post_ops->seq_length = cur_op_index; \
|
||||
|
||||
Reference in New Issue
Block a user