mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Bugfix for A matrix packing in int8(S8/U8) APIs for Batch-Matmul
- A matrix by default isn't expected to be packed for a normal row-stored case. Hence the packing implementation is incomplete. - But if the user explicitly enables packing, interface wasn't handling the condition appropriately leading to data overwriting inside the incomplete pack kernels, thereby leading to accuracy failure. - As a fix, updated the interface to set the explicit PACK A to UNPACKED and proceed with GEMM in cases where transpose of A is not necessary. - Updated the batch gemm input file with additional test cases covering all the APIs. Bug Fixes: - Fixed implementation logic for column major inputs with post-ops to be disabled in S8 batch mat-mul. With the existing implementation, column-major inputs wouldn't be executed in case of of32/os32 inputs. - Fixed the Scale/ZP calculation in bench foru8s8s32ou8 condition, which was leading to accuracy failures. [AMD-Internal: CPUPL-7283 ]
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
f32f32f32of32:group_count=1
|
||||
f32f32f32of32:group_count=2
|
||||
group_size=4
|
||||
r n n n n 6 64 128 128 64 64 bias=bf16,relu,swish
|
||||
group_size=3
|
||||
r t t n n 92 1479 589 92 589 1479 scale=vector,zp=vector,bias=na,clip
|
||||
r n t n n 78 9810 1229 1229 9810 9810 matrix_add=bf16,matrix_mul=f32
|
||||
s8s8s32obf16:group_count=1
|
||||
group_size=5
|
||||
r n n n r 67 21 1823 1823 21 21 scale=vector,zp=scalar,relu,clip
|
||||
@@ -21,4 +23,7 @@ group_size=6
|
||||
r n n n r 17 2714 468 468 2714 2714 scale=vector,zp=vector,bias=na
|
||||
s8s8s32obf16:group_count=1
|
||||
group_size=4
|
||||
r n n n n 43 2240 1553 1553 2240 2240 scale=vector,zp=scalar,relu,clip
|
||||
r n n n n 43 2240 1553 1553 2240 2240 scale=vector,zp=scalar,relu,clip
|
||||
*:group_count=1
|
||||
group_size=3
|
||||
r t t n n 92 1479 589 92 589 1479 scale=vector,zp=vector,bias=na,clip
|
||||
@@ -634,7 +634,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \
|
||||
(post_temp_accum, post_op[gc_i], j, ( post_op[gc_i]->sum )->sf_stor_type, \
|
||||
( post_op[gc_i]->sum )->zp_stor_type); \
|
||||
} \
|
||||
else if ( post_op[mat_idx + gs_i]->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
else if ( post_op[gc_i]->seq_vector[op_id] == MATRIX_ADD ) \
|
||||
{ \
|
||||
dim_t rs_m = ( post_op[gc_i]->matrix_add )->ldm; \
|
||||
dim_t cs_m = 1; \
|
||||
|
||||
@@ -2238,7 +2238,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )( post_ops->sum )->scale_factor; \
|
||||
GEN_FUNC_NAME(fill_array_,DSCALE_type)(temp_dscale_ptr, n_scale); \
|
||||
( post_ops->sum )->scale_factor_len = n_scale; \
|
||||
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
|
||||
if(!strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_scale;i++) temp_dscale_ptr[i] = abs(temp_dscale_ptr[i]);\
|
||||
} \
|
||||
\
|
||||
if(is_zp_stor_type == TRUE) \
|
||||
@@ -2311,7 +2311,7 @@ static inline aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \
|
||||
C_DSCALE_type* temp_dzero_point_ptr = ( C_DSCALE_type* )( post_ops->sum )->zero_point; \
|
||||
GEN_FUNC_NAME(fill_array_,C_DSCALE_type)(temp_dzero_point_ptr, n_zp); \
|
||||
( post_ops->sum )->zero_point_len = n_zp; \
|
||||
if(strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_zp;i++) temp_dzero_point_ptr[i] = abs(temp_dzero_point_ptr[i]);\
|
||||
if(!strcmp(#BLAS_SFX, "u8s8s32ou8")) for(dim_t i=0;i<n_zp;i++) temp_dzero_point_ptr[i] = abs(temp_dzero_point_ptr[i]);\
|
||||
} \
|
||||
} \
|
||||
\
|
||||
|
||||
Reference in New Issue
Block a user