Added support to specify bias data type in u8s8s32/s8s8s32 API's

Description:
1. The bias type was supported only based on output data type.
2. The option is added in the pre-ops structure to select the bias data
   type(s8/s32/bf16) irrespective of the storage data type in
   u8s8s32/s8s8s32 API's.

AMD-Internal: SWLCSG-3302

Change-Id: I3c465fe428672d2d58c1c60115c46d2d5b11f0f4
This commit is contained in:
Deepak Negi
2025-01-14 08:05:22 +05:30
committed by Nallani Bhaskar
parent 2f2741f4ab
commit 182a6373b5
17 changed files with 1491 additions and 352 deletions

View File

@@ -764,9 +764,9 @@ GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL(float,bf16s4f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)

View File

@@ -778,9 +778,9 @@ GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os8)
GEN_GET_BIAS_POST_OP_VAL(int32_t,s8s8s32os32)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os8)
GEN_GET_BIAS_POST_OP_VAL(int16_t,s8s8s16os16)
GEN_GET_BIAS_POST_OP_VAL(float,f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL(float,bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL(float,bf16s4f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(f32f32f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16bf16f32of32)
GEN_GET_BIAS_POST_OP_VAL_f32(bf16s4f32of32)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t)

View File

@@ -98,6 +98,18 @@ static inline void bfloat16_to_float( bfloat16 bf16_val, float* float_val )
memcpy( float_val, &inter_temp, sizeof( int32_t ) );
}
static inline void bfloat16_to_int32_t( bfloat16 bf16_val, int32_t* int_val )
{
int32_t inter_temp = *( ( int16_t* ) &bf16_val );
inter_temp = inter_temp << 16;
memcpy( int_val, &inter_temp, sizeof( int32_t ) );
}
static inline void int8_t_to_int32_t( int8_t int8_t_val, int32_t* int_val )
{
*int_val = (int32_t)int8_t_val;
}
static inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, dim_t size )
{
#ifdef BLIS_ENABLE_OPENMP
@@ -211,6 +223,7 @@ static inline void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \
} \
} \
GEN_FILL_ARRAY_POST_OPS_FUNC(int8_t)
GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t)
GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t)
GEN_FILL_ARRAY_POST_OPS_FUNC(float)
@@ -240,6 +253,23 @@ static inline float get_bias_post_op_val_ ## BLAS_SFX \
return ret_val; \
} \
#define GEN_GET_BIAS_POST_OP_VAL_f32(BLAS_SFX) \
static inline float get_bias_post_op_val_ ## BLAS_SFX \
( \
void* post_op_bias_ptr, \
dim_t j, \
AOCL_PARAMS_STORAGE_TYPES bias_stor_type \
) \
{ \
float ret_val = 0.0; \
if(bias_stor_type == AOCL_GEMM_BF16) \
{ \
bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
return ret_val; \
} \
return *( ( float* )post_op_bias_ptr + j ); \
} \
#define GEN_GET_BIAS_POST_OP_VAL(ACCUM_type,BLAS_SFX) \
static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
( \
@@ -250,8 +280,14 @@ static inline ACCUM_type get_bias_post_op_val_ ## BLAS_SFX \
{ \
if(bias_stor_type == AOCL_GEMM_BF16) \
{ \
float ret_val = 0.0; \
bfloat16_to_float( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
int32_t ret_val = 0.0; \
bfloat16_to_int32_t( *( ( bfloat16* )post_op_bias_ptr + j ), &ret_val ); \
return ret_val; \
} \
if(bias_stor_type == AOCL_GEMM_INT8) \
{ \
int32_t ret_val = 0.0; \
int8_t_to_int32_t( *( ( int8_t* )post_op_bias_ptr + j ), &ret_val ); \
return ret_val; \
} \
return *( ( ACCUM_type* )post_op_bias_ptr + j ); \
@@ -729,6 +765,16 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
is_bias_stor_type = TRUE; \
bias_stor_type = "BF16"; \
} \
else if ( ( strcmp( ops_tok, "s32" ) == 0 ) ) \
{ \
is_bias_stor_type = TRUE; \
bias_stor_type = "S32"; \
} \
else if ( ( strcmp( ops_tok, "s8" ) == 0 ) ) \
{ \
is_bias_stor_type = TRUE; \
bias_stor_type = "S8"; \
} \
is_bias = TRUE; \
cur_op_index++; \
} \
@@ -920,6 +966,16 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
( post_ops->bias )->stor_type = AOCL_GEMM_F32; \
GEN_FUNC_NAME(fill_array_post_ops_,float)( ( post_ops->bias )->bias, n ); \
} \
else if( ( strcmp( bias_stor_type, "S8" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_INT8; \
GEN_FUNC_NAME(fill_array_post_ops_,int8_t)( ( post_ops->bias )->bias, n ); \
} \
else if( ( strcmp( bias_stor_type, "S32" ) == 0 ) ) \
{ \
( post_ops->bias )->stor_type = AOCL_GEMM_INT32; \
GEN_FUNC_NAME(fill_array_post_ops_,int32_t)( ( post_ops->bias )->bias, n ); \
} \
else {} \
} \
else \