Bug fixes in bench and pack code for s8 and bf16 datatypes

Details:
- Fixed the logic to identify an API that has int4 weights in
  bench files for gemm and batch_gemm.
- Eliminated the memcpy instructions used in pack functions of
  zen4 kernels and replaced them with masked load instruction.
  This ensures that the load register will be populated with
  zeroes at locations where mask is set to zero.

Change-Id: I8dd1ea7779c8295b7b4adec82069e80c6493155e
AMD-Internal:[SWLCSG-3274]
This commit is contained in:
Meghana Vankadari
2025-02-28 05:10:35 +00:00
parent b4c1026ec2
commit 6c29236166
4 changed files with 22 additions and 299 deletions

View File

@@ -553,16 +553,13 @@ void packb_nrlt16_bf16bf16f32of32_row_major
dim_t kr_new = 0;
bfloat16 buf0[16];
bfloat16 buf1[16];
__mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
for ( int kr = 0; kr < k_full_pieces; kr += 2 )
{
memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
// Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row.
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 );
c0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf1 );
a0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( kr + 0 ) ) );
c0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( kr + 1 ) ) );
a01 = _mm256_unpacklo_epi16( a0, c0 );
a0 = _mm256_unpackhi_epi16( a0, c0 );
@@ -587,8 +584,7 @@ void packb_nrlt16_bf16bf16f32of32_row_major
// Handle k remainder.
if ( k_partial_pieces > 0 )
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 );
a0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
c0 = _mm256_setzero_si256();
a01 = _mm256_unpacklo_epi16( a0, c0 );

View File

@@ -1050,11 +1050,6 @@ void packb_nrlt16_s8s8s32os32_row_major
)
{
dim_t NR = 64;
int8_t buf0[16];
int8_t buf1[16];
int8_t buf2[16];
int8_t buf3[16];
dim_t kr_new = 0;
dim_t k_full_pieces_blks = KC / 4;
@@ -1076,18 +1071,15 @@ void packb_nrlt16_s8s8s32os32_row_major
//load the temp buffer to compute column sum of B matrix
sum1 = _mm512_loadu_si512( pack_b_column_sum );
__mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 )
{
memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
// Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row.
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 );
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 0 ) ) );
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 1 ) ) );
c0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 2 ) ) );
d0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 3 ) ) );
//add all the columns : sum = add (sum, a0, b0, c0, d0)
sum1 =
@@ -1128,13 +1120,9 @@ void packb_nrlt16_s8s8s32os32_row_major
{
if ( k_partial_pieces == 3 )
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 1 ) ) );
c0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 2 ) ) );
d0_16 = _mm_setzero_si128();
sum1 =
@@ -1148,11 +1136,8 @@ void packb_nrlt16_s8s8s32os32_row_major
}
else if( k_partial_pieces == 2 )
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 1 ) ) );
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();
@@ -1164,9 +1149,7 @@ void packb_nrlt16_s8s8s32os32_row_major
}
else //k_partial_pieces == 1
{
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
b0_16 = _mm_setzero_si128();
c0_16 = _mm_setzero_si128();
d0_16 = _mm_setzero_si128();