Bug fixes in bench and pack code for s8 and bf16 datatypes

Details:
- Fixed the logic to identify an API that has int4 weights in
  bench files for gemm and batch_gemm.
- Eliminated the memcpy instructions used in pack functions of
  zen4 kernels and replaced them with masked load instructions.
  This ensures that the load register will be populated with
  zeroes at locations where the mask is set to zero.

Change-Id: I8dd1ea7779c8295b7b4adec82069e80c6493155e
AMD-Internal:[SWLCSG-3274]
This commit is contained in:
Meghana Vankadari
2025-02-28 05:10:35 +00:00
parent b4c1026ec2
commit 6c29236166
4 changed files with 22 additions and 299 deletions

View File

@@ -432,261 +432,6 @@ GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,float)
GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(uint8_t,float)
#if 0
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
( \
FILE* fout, \
const char* stor_order, \
char* transa, \
char* transb, \
dim_t bs, \
dim_t* m, \
dim_t* n, \
dim_t* k, \
ACCUM_type* alpha, \
A_type** a, \
dim_t* lda, \
B_type** b, \
dim_t* ldb, \
ACCUM_type* beta, \
C_type** c, \
dim_t* ldc, \
C_type** c_ref, \
dim_t* ldc_ref, \
aocl_post_op** post_op, \
bool int4_testing /* Workaround to enable int4 B matrix testing. */ \
) \
{ \
dim_t rs_a, cs_a; \
dim_t rs_b, cs_b; \
dim_t rs_c, cs_c; \
dim_t rs_c_ref; \
dim_t cs_c_ref; \
for( dim_t bs_i = 0; bs_i < bs; bs_i++ ) \
{ \
if( stor_order[bs_i] == 'r' || stor_order[bs_i] == 'R' ) \
{ \
if( ( transa[bs_i] == 'n' ) || ( transa[bs_i] == 'N' ) ) \
{ \
rs_a = lda[bs_i]; \
cs_a = 1; \
} \
else \
{ \
rs_a = 1; \
cs_a = lda[bs_i]; \
} \
if( ( transb[bs_i] == 'n' ) || ( transb[bs_i] == 'N' ) ) \
{ \
rs_b = ldb[bs_i]; \
cs_b = 1; \
} \
else \
{ \
rs_b = 1; \
cs_b = ldb[bs_i]; \
} \
rs_c = ldc[bs_i]; \
cs_c = 1; \
rs_c_ref = ldc_ref[bs_i]; \
cs_c_ref = 1; \
} \
else /* column storage */ \
{ \
if( transa[bs_i] == 'n' || transa[bs_i] == 'N') \
{ \
rs_a = 1; \
cs_a = lda[bs_i]; \
} \
else \
{ \
rs_a= lda[bs_i]; \
cs_a = 1; \
} \
if( ( transb[bs_i] == 'n' ) || ( transb[bs_i] == 'N' ) ) \
{ \
rs_b = 1; \
cs_b = ldb[bs_i]; \
} \
else \
{ \
rs_b = ldb[bs_i]; \
cs_b = 1; \
} \
rs_c = 1; \
cs_c = ldc[bs_i]; \
rs_c_ref = 1; \
cs_c_ref = ldc_ref[bs_i]; \
} \
aocl_pre_op* a_pre_op = NULL; \
if ( post_op[bs_i] != NULL ) \
{ \
a_pre_op = post_op[bs_i]->pre_ops; \
} \
for ( dim_t i = 0; i < m[bs_i]; ++i ) \
{ \
for ( dim_t j = 0; j < n[bs_i]; ++j ) \
{ \
ACCUM_type temp_accum = 0; \
C_type out_temp_accum = 0; \
\
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \
(a[bs_i], b[bs_i], c_ref[bs_i], temp_accum, alpha[bs_i], beta[bs_i],\
rs_a, rs_b, cs_a, cs_b, rs_c_ref, cs_c_ref, i, j, k[bs_i], \
int4_testing, a_pre_op); \
\
if ( post_op[bs_i] != NULL ) \
{ \
dim_t ele_i = 0; \
for ( dim_t op_id = 0; op_id < post_op[bs_i]->seq_length; ++op_id ) \
{ \
if ( post_op[bs_i]->seq_vector[op_id] == BIAS ) \
{ \
temp_accum += GEN_FUNC_NAME(get_bias_post_op_val_,BLAS_SFX) \
( ( post_op[bs_i]->bias )->bias, j, ( post_op[bs_i]->bias )->stor_type ); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == ELTWISE ) \
{ \
if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
PRELU ) /* PReLU*/ \
{ \
temp_accum = ( temp_accum > 0 ) ? \
temp_accum : \
( temp_accum * \
*( ( ACCUM_type* ) ( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) ); \
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
GELU_TANH ) /* TANH GeLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
GELU_ERF ) /* ERF GeLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
SWISH ) /* SiLU*/ \
{ \
temp_accum = GEN_FUNC_NAME(SWISH_post_op_,BLAS_SFX) \
(temp_accum, \
( post_op[bs_i]->eltwise + ele_i )->algo.alpha );\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
RELU ) /* ReLU*/ \
{ \
temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
TANH ) /* TANH*/ \
{ \
temp_accum = GEN_FUNC_NAME(TANH_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
SIGMOID ) /* Sigmoid*/ \
{ \
temp_accum = GEN_FUNC_NAME(SIGMOID_post_op_,BLAS_SFX) (temp_accum);\
ele_i += 1; \
} \
else if ( ( post_op[bs_i]->eltwise + ele_i )->algo.algo_type == \
CLIP ) /* CLIP*/ \
{ \
temp_accum = \
min \
( \
max \
( \
temp_accum, \
*( ( ACCUM_type* ) \
( post_op[bs_i]->eltwise + ele_i )->algo.alpha ) \
), \
*( ( ACCUM_type* ) \
( post_op[bs_i]->eltwise + ele_i )->algo.beta) \
); \
ele_i += 1; \
} \
else \
{} \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == SCALE ) \
{ \
temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \
(temp_accum, post_op[bs_i], j); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_ADD ) \
{ \
dim_t rs_m = ( post_op[bs_i]->matrix_add )->ldm; \
dim_t cs_m = 1; \
if ( ( stor_order[bs_i] == 'C' ) || ( stor_order[bs_i] == 'c' ) ) \
{ \
cs_m = rs_m; \
rs_m = 1; \
} \
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_add )->scale_factor ); \
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_add )->scale_factor_len; \
temp_accum += GEN_FUNC_NAME(get_matrix_add_post_op_val_,BLAS_SFX) \
( ( post_op[bs_i]->matrix_add )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op[bs_i]->matrix_add)->stor_type ); \
} \
else if ( post_op[bs_i]->seq_vector[op_id] == MATRIX_MUL ) \
{ \
dim_t rs_m = ( post_op[bs_i]->matrix_mul )->ldm; \
dim_t cs_m = 1; \
if ( ( stor_order[bs_i] == 'C' ) || ( stor_order[bs_i] == 'c' ) ) \
{ \
cs_m = rs_m; \
rs_m = 1; \
} \
float* scl_fctr = ( float* )( ( post_op[bs_i]->matrix_mul )->scale_factor ); \
dim_t scl_fctr_len = ( post_op[bs_i]->matrix_mul )->scale_factor_len; \
temp_accum *= GEN_FUNC_NAME(get_matrix_mul_post_op_val_,BLAS_SFX) \
(( post_op[bs_i]->matrix_mul )->matrix, i, \
j, rs_m, cs_m, scl_fctr, scl_fctr_len, ( post_op[bs_i]->matrix_mul)->stor_type ); \
} \
else \
{} \
} \
} \
/* Need to convert to downscaled type if required.*/ \
mat_mul_get_output_type_val ## ACCUM_type ## C_type \
( \
&out_temp_accum, &temp_accum \
); \
\
float comp_float, ref_float; \
GEN_FUNC_NAME(C_type,_to_float)(*( c[bs_i] + ( rs_c * i ) + ( cs_c * j ) ), &comp_float); \
GEN_FUNC_NAME(C_type,_to_float)(out_temp_accum, &ref_float); \
\
if ( ( ( comp_float - ref_float ) > 1.0E-5 ) || \
( ( ref_float - comp_float ) > 1.0E-5 ) ) \
{ \
if ( fout ) \
{ \
fprintf( fout, "%s Failure input gemm:%ld, m: %ld, n: %ld, k: %ld," \
" lda: %ld, ldb: %ld, ldc: %ld, computed:%f, ref:%f, diff:%f\n", \
XSTR(BLAS_SFX), bs_i, m[bs_i], n[bs_i], k[bs_i], lda[bs_i], ldb[bs_i], ldc[bs_i], comp_float, \
ref_float, comp_float - ref_float); \
fflush( fout ); \
} \
printf("failure, gemm:%ld, m_index: %ld, n_index: %ld, k: %ld, computed:%f, ref:%f," \
"diff:%f\n", bs_i, i, j, k[bs_i], comp_float, ref_float, comp_float-ref_float); \
goto cleanup_acc; \
} \
} \
} \
} \
cleanup_acc: \
return; \
} \
#else
#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,POST_ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \
void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
( \
@@ -941,7 +686,6 @@ cleanup_acc: \
return; \
} \
#endif
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,float,u8s8s32os32,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,float,u8s8s32os8,u8s8s32os8)
GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,uint8_t,int32_t,float,float,u8s8s32ou8,u8s8s32ou8)
@@ -1021,7 +765,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
Sum_type alpha[bs]; \
Sum_type beta[bs]; \
aocl_post_op** post_op = (aocl_post_op**)lpgemm_malloc(sizeof(aocl_post_op*) * bs); \
bool int4_testing = ( strcmp(#BLAS_SFX,"bf16s4f32of32") | strcmp(#BLAS_SFX,"bf16s4f32obf16") ); \
bool int4_testing = ( ( strcmp(#BLAS_SFX,"bf16s4f32of32") == 0 ) || ( strcmp(#BLAS_SFX,"bf16s4f32obf16") == 0 ) ); \
for( dim_t i = 0; i < bs; i++ ) \
{ \
dim_t size_A = 0; \
@@ -1042,7 +786,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
a[i] = ( A_type* ) lpgemm_malloc( sizeof( A_type ) * size_A ); \
GEN_FUNC_NAME(fill_array_,A_type)(a[i], size_A ); \
b[i] = ( B_type* ) lpgemm_malloc( sizeof( B_type ) * size_B ); \
if ( int4_testing != FALSE ) \
if ( int4_testing == FALSE ) \
{ \
GEN_FUNC_NAME(fill_array_,B_type)(b[i], size_B ); \
} \
@@ -1095,7 +839,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
} \
else if ( ( op_b[i] == 'r' ) || ( op_b[i] == 'R' ) ) \
{ \
if ( int4_testing != FALSE ) \
if ( int4_testing == FALSE ) \
{ \
siz_t b_reorder_buf_siz_req = \
GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( stor_order[i], transb[i], 'B', k[i], n[i] ); \

View File

@@ -756,10 +756,10 @@ void mat_mul_bench_main_ ## BLAS_SFX \
} \
A_type* a = ( A_type* ) lpgemm_malloc( sizeof( A_type ) * size_A ); \
GEN_FUNC_NAME(fill_array_,A_type)(a, size_A ); \
bool int4_testing = ( strcmp(#BLAS_SFX,"bf16s4f32of32") | strcmp(#BLAS_SFX,"bf16s4f32obf16") ); \
bool int4_testing = ( ( strcmp(#BLAS_SFX,"bf16s4f32of32") == 0 ) || (strcmp(#BLAS_SFX,"bf16s4f32obf16") == 0 ) ); \
\
B_type* b = ( B_type* ) lpgemm_malloc( sizeof( B_type ) * size_B ); \
if ( int4_testing != FALSE ) \
if ( int4_testing == FALSE ) \
{ \
GEN_FUNC_NAME(fill_array_,B_type)(b, size_B ); \
} \
@@ -829,7 +829,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \
{ \
B_type* b_reorder = NULL; \
/* Reorder B.*/ \
if ( int4_testing != FALSE ) \
if ( int4_testing == FALSE ) \
{ \
siz_t b_reorder_buf_siz_req = \
GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( stor_order, transb, 'B', k, n ); \

View File

@@ -553,16 +553,13 @@ void packb_nrlt16_bf16bf16f32of32_row_major
dim_t kr_new = 0;
bfloat16 buf0[16];
bfloat16 buf1[16];
__mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
for ( int kr = 0; kr < k_full_pieces; kr += 2 )
{
memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
// Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row.
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 );
c0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf1 );
a0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( kr + 0 ) ) );
c0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( kr + 1 ) ) );
a01 = _mm256_unpacklo_epi16( a0, c0 );
a0 = _mm256_unpackhi_epi16( a0, c0 );
@@ -587,8 +584,7 @@ void packb_nrlt16_bf16bf16f32of32_row_major
// Handle k remainder.
if ( k_partial_pieces > 0 )
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 );
a0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
c0 = _mm256_setzero_si256();
a01 = _mm256_unpacklo_epi16( a0, c0 );

View File

@@ -1050,11 +1050,6 @@ void packb_nrlt16_s8s8s32os32_row_major
)
{
dim_t NR = 64;
int8_t buf0[16];
int8_t buf1[16];
int8_t buf2[16];
int8_t buf3[16];
dim_t kr_new = 0;
dim_t k_full_pieces_blks = KC / 4;
@@ -1076,18 +1071,15 @@ void packb_nrlt16_s8s8s32os32_row_major
//load the temp buffer to compute column sum of B matrix
sum1 = _mm512_loadu_si512( pack_b_column_sum );
__mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 )
{
memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
// Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row.
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 );
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 0 ) ) );
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 1 ) ) );
c0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 2 ) ) );
d0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 3 ) ) );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
@@ -1128,13 +1120,9 @@ void packb_nrlt16_s8s8s32os32_row_major
{
if ( k_partial_pieces == 3 )
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 1 ) ) );
c0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 2 ) ) );
d0_16 = _mm_setzero_si128();
sum1 =
@@ -1148,11 +1136,8 @@ void packb_nrlt16_s8s8s32os32_row_major
}
else if( k_partial_pieces == 2 )
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 1 ) ) );
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
@@ -1164,9 +1149,7 @@ void packb_nrlt16_s8s8s32os32_row_major
}
else //k_partial_pieces == 1
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
b0_16 = _mm_setzero_si128();
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();