mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Added support to specify bias data type in u8s8s32/s8s8s32 APIs
Description: 1. Previously, the bias data type was determined only by the output data type. 2. An option is added in the pre-ops structure to select the bias data type (s8/s32/bf16) irrespective of the storage data type in the u8s8s32/s8s8s32 APIs. AMD-Internal: SWLCSG-3302 Change-Id: I3c465fe428672d2d58c1c60115c46d2d5b11f0f4
This commit is contained in:
committed by
Nallani Bhaskar
parent
2f2741f4ab
commit
182a6373b5
@@ -764,9 +764,9 @@ GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,bf16s4f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
|
||||
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)
|
||||
|
||||
@@ -778,9 +778,9 @@ GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
|
||||
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL(float,bf16s4f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
|
||||
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
|
||||
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
|
||||
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)
|
||||
|
||||
@@ -98,6 +98,18 @@ static inline void bfloat16_to_float( bfloat16 bf16_val, float* float_val )
|
||||
memcpy( float_val, &inter_temp, sizeof( int32_t ) );
|
||||
}
|
||||
|
||||
/*
 * Place the 16 raw bits of bf16_val in the upper half of a 32-bit word
 * (lower 16 bits zero) and store that word in *int_val.
 *
 * The value written is a bit pattern, not a numeric float-to-integer
 * conversion of the bfloat16.
 *
 * The shift is done on an unsigned word: left-shifting a negative signed
 * value (any bf16 with its sign bit set, once sign-extended) is undefined
 * behavior in C. The raw bits are fetched with memcpy rather than through
 * a pointer cast.
 */
static inline void bfloat16_to_int32_t( bfloat16 bf16_val, int32_t* int_val )
{
    uint16_t raw_bits = 0;
    memcpy( &raw_bits, &bf16_val, sizeof( raw_bits ) );

    uint32_t widened = ( uint32_t )raw_bits << 16;
    memcpy( int_val, &widened, sizeof( int32_t ) );
}
|
||||
|
||||
/*
 * Widen a signed 8-bit value to 32 bits and store it in *int_val.
 * The conversion is value-preserving (sign is retained).
 */
static inline void int8_t_to_int32_t( int8_t int8_t_val, int32_t* int_val )
{
    const int32_t widened = int8_t_val;
    *int_val = widened;
}
|
||||
|
||||
static inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, dim_t size )
|
||||
{
|
||||
#ifdef BLIS_ENABLE_OPENMP
|
||||
@@ -211,6 +223,7 @@ static inline void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \
|
||||
} \
|
||||
} \
|
||||
|
||||
GEN_FILL_ARRAY_POST_OPS_FUNC(int8_t)
|
||||
GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t)
|
||||
GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t)
|
||||
GEN_FILL_ARRAY_POST_OPS_FUNC(float)
|
||||
@@ -240,6 +253,23 @@ static inline float get_bias_post_op_val_ ## BLAS_SFX \
|
||||
return ret_val; \
|
||||
} \
|
||||
|
||||
/*
 * Generates get_bias_post_op_val_<BLAS_SFX>(), which returns element j of
 * the bias array as a float.
 *
 * When bias_stor_type is AOCL_GEMM_BF16 the array is read as bfloat16 and
 * the element is converted with bfloat16_to_float(); for every other
 * storage type the array is read directly as float.
 */
#define GEN_GET_BIAS_POST_OP_VAL_f32(BLAS_SFX) \
static inline float get_bias_post_op_val_ ## BLAS_SFX \
     ( \
       void* post_op_bias_ptr, \
       dim_t j, \
       AOCL_PARAMS_STORAGE_TYPES bias_stor_type \
     ) \
{ \
    if ( bias_stor_type != AOCL_GEMM_BF16 ) \
    { \
        return ( ( float* )post_op_bias_ptr )[j]; \
    } \
    float bias_f32 = 0.0; \
    bfloat16_to_float( ( ( bfloat16* )post_op_bias_ptr )[j], &bias_f32 ); \
    return bias_f32; \
}
|
||||
|
||||
#define GEN_GET_BIAS_POST_OP_VAL(ACCUM_type,BLAS_SFX) \
|
||||
static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
|
||||
( \
|
||||
@@ -250,8 +280,14 @@ static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
|
||||
{ \
|
||||
if(bias_stor_type == AOCL_GEMM_BF16) \
|
||||
{ \
|
||||
float ret_val = 0.0; \
|
||||
bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
int32_t ret_val = 0.0; \
|
||||
bfloat16_to_int32_t( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
return ret_val; \
|
||||
} \
|
||||
if(bias_stor_type == AOCL_GEMM_INT8) \
|
||||
{ \
|
||||
int32_t ret_val = 0.0; \
|
||||
int8_t_to_int32_t( *( ( int8_t* )post_op_bias_ptr + j ), &ret_val ); \
|
||||
return ret_val; \
|
||||
} \
|
||||
return *( ( ACCUM_type* )post_op_bias_ptr + j ); \
|
||||
@@ -729,6 +765,16 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
is_bias_stor_type = TRUE; \
|
||||
bias_stor_type = "BF16"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
|
||||
{ \
|
||||
is_bias_stor_type = TRUE; \
|
||||
bias_stor_type = "S32"; \
|
||||
} \
|
||||
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
|
||||
{ \
|
||||
is_bias_stor_type = TRUE; \
|
||||
bias_stor_type = "S8"; \
|
||||
} \
|
||||
is_bias = TRUE; \
|
||||
cur_op_index++; \
|
||||
} \
|
||||
@@ -920,6 +966,16 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,float)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "S8" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_INT8; \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,int8_t)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else if( ( strcmp( bias_stor_type, "S32" ) == 0 ) ) \
|
||||
{ \
|
||||
( post_ops->bias )->stor_type = AOCL_GEMM_INT32; \
|
||||
GEN_FUNC_NAME(fill_array_post_ops_,int32_t)( ( post_ops->bias )->bias, n ); \
|
||||
} \
|
||||
else {} \
|
||||
} \
|
||||
else \
|
||||
|
||||
Reference in New Issue
Block a user