Bugfix for A matrix packing in int8(S8/U8) APIs for Batch-Matmul

- A matrix by default isn't expected to be packed for a normal row-stored
 case. Hence the packing implementation is incomplete.
 - But if the user explicitly enables packing, interface wasn't handling
 the condition appropriately leading to data overwriting inside the incomplete
 pack kernels, thereby leading to accuracy failure.
 - As a fix, updated the interface to set the explicit PACK A to UNPACKED and
 proceed with GEMM in cases where transpose of A is not necessary.
 - Updated the batch gemm input file with additional test cases covering all the
 APIs.
Bug Fixes:
 - Fixed implementation logic for column major inputs with post-ops to be disabled
 in S8 batch mat-mul. With the existing implementation, column-major inputs wouldn't
 be executed in case of of32/os32 inputs.
 - Fixed the Scale/ZP calculation in bench foru8s8s32ou8 condition, which was leading
 to accuracy failures.

[AMD-Internal: CPUPL-7283 ]
This commit is contained in:
V, Varsha
2025-08-26 16:46:37 +05:30
committed by GitHub
parent deafc527fc
commit 3df4aac2d2
5 changed files with 370 additions and 266 deletions

View File

@@ -1,6 +1,8 @@
f32f32f32of32:group_count=1
f32f32f32of32:group_count=2
group_size=4
r n n n n 6 64 128 128 64 64 bias=bf16,relu,swish
group_size=3
r t t n n 92 1479 589 92 589 1479 scale=vector,zp=vector,bias=na,clip
r n t n n 78 9810 1229 1229 9810 9810 matrix_add=bf16,matrix_mul=f32
s8s8s32obf16:group_count=1
group_size=5
r n n n r 67 21 1823 1823 21 21 scale=vector,zp=scalar,relu,clip
@@ -21,4 +23,7 @@ group_size=6
r n n n r 17 2714 468 468 2714 2714 scale=vector,zp=vector,bias=na
s8s8s32obf16:group_count=1
group_size=4
r n n n n 43 2240 1553 1553 2240 2240 scale=vector,zp=scalar,relu,clip
r n n n n 43 2240 1553 1553 2240 2240 scale=vector,zp=scalar,relu,clip
*:group_count=1
group_size=3
r t t n n 92 1479 589 92 589 1479 scale=vector,zp=vector,bias=na,clip

View File

@@ -634,7 +634,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
(post_temp_accum, post_op[gc_i], j, ( post_op[gc_i]->sum )->sf_stor_type, \
( post_op[gc_i]->sum )->zp_stor_type); \
} \
else if ( post_op[mat_idx + gs_i]->seq_vector[op_id] == MATRIX_ADD ) \
else if ( post_op[gc_i]->seq_vector[op_id] == MATRIX_ADD ) \
{ \
dim_t rs_m = ( post_op[gc_i]->matrix_add )->ldm; \
dim_t cs_m = 1; \

View File

@@ -2238,7 +2238,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
GEN_FUNC_NAME(fill_array_,DSCALE_type)(temp_dscale_ptr, n_scale); \
( post_ops->sum )->scale_factor_len = n_scale; \
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
if(!strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
} \
\
if(is_zp_stor_type == TRUE) \
@@ -2311,7 +2311,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)(temp_dzero_point_ptr, n_zp); \
( post_ops->sum )->zero_point_len = n_zp; \
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_zp;i++) temp_dzero_point_ptr[i] = abs(temp_dzero_point_ptr[i]);\
if(!strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_zp;i++) temp_dzero_point_ptr[i] = abs(temp_dzero_point_ptr[i]);\
} \
} \
\