mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Bug fixes in bench and pack code for s8 and bf16 datatypes
Details: - Fixed the logic to identify an API that has int4 weights in bench files for gemm and batch_gemm. - Eliminated the memcpy instructions used in pack functions of zen4 kernels and replaced them with masked load instruction. This ensures that the load register will be populated with zeroes at locations where mask is set to zero. Change-Id: I8dd1ea7779c8295b7b4adec82069e80c6493155e AMD-Internal:[SWLCSG-3274]
This commit is contained in:
@@ -553,16 +553,13 @@ void packb_nrlt16_bf16bf16f32of32_row_major
|
||||
|
||||
dim_t kr_new = 0;
|
||||
|
||||
bfloat16 buf0[16];
|
||||
bfloat16 buf1[16];
|
||||
__mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
|
||||
|
||||
for ( int kr = 0; kr < k_full_pieces; kr += 2 )
|
||||
{
|
||||
memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
|
||||
memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
|
||||
// Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row.
|
||||
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 );
|
||||
c0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf1 );
|
||||
a0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( kr + 0 ) ) );
|
||||
c0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( kr + 1 ) ) );
|
||||
|
||||
a01 = _mm256_unpacklo_epi16( a0, c0 );
|
||||
a0 = _mm256_unpackhi_epi16( a0, c0 );
|
||||
@@ -587,8 +584,7 @@ void packb_nrlt16_bf16bf16f32of32_row_major
|
||||
// Handle k remainder.
|
||||
if ( k_partial_pieces > 0 )
|
||||
{
|
||||
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) );
|
||||
a0 = _mm256_maskz_loadu_epi16( 0xFFFF, buf0 );
|
||||
a0 = _mm256_maskz_loadu_epi16( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
|
||||
c0 = _mm256_setzero_si256();
|
||||
|
||||
a01 = _mm256_unpacklo_epi16( a0, c0 );
|
||||
|
||||
@@ -1050,11 +1050,6 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
)
|
||||
{
|
||||
dim_t NR = 64;
|
||||
int8_t buf0[16];
|
||||
int8_t buf1[16];
|
||||
int8_t buf2[16];
|
||||
int8_t buf3[16];
|
||||
|
||||
dim_t kr_new = 0;
|
||||
|
||||
dim_t k_full_pieces_blks = KC / 4;
|
||||
@@ -1076,18 +1071,15 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
//load the temp buffer to compute column sum of B matrix
|
||||
sum1 = _mm512_loadu_si512( pack_b_column_sum );
|
||||
|
||||
__mmask16 load_mask = _cvtu32_mask16( 0xFFFF >> ( 16 - n0_partial_rem ) );
|
||||
|
||||
for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 )
|
||||
{
|
||||
memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
|
||||
// Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row.
|
||||
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
|
||||
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
|
||||
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 );
|
||||
d0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf3 );
|
||||
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 0 ) ) );
|
||||
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 1 ) ) );
|
||||
c0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 2 ) ) );
|
||||
d0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( kr + 3 ) ) );
|
||||
|
||||
//add all the columns : sum = add (sum, a0, b0, c0, d0)
|
||||
sum1 =
|
||||
@@ -1128,13 +1120,9 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
{
|
||||
if ( k_partial_pieces == 3 )
|
||||
{
|
||||
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
|
||||
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
|
||||
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
|
||||
c0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf2 );
|
||||
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
|
||||
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 1 ) ) );
|
||||
c0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 2 ) ) );
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
sum1 =
|
||||
@@ -1148,11 +1136,8 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
}
|
||||
else if( k_partial_pieces == 2 )
|
||||
{
|
||||
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
|
||||
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
|
||||
b0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf1 );
|
||||
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
|
||||
b0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 1 ) ) );
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
@@ -1164,9 +1149,7 @@ void packb_nrlt16_s8s8s32os32_row_major
|
||||
}
|
||||
else //k_partial_pieces == 1
|
||||
{
|
||||
memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) );
|
||||
|
||||
a0_16 = _mm_maskz_loadu_epi8( 0xFFFF, buf0 );
|
||||
a0_16 = _mm_maskz_loadu_epi8( load_mask, b + ( ldb * ( k_full_pieces + 0 ) ) );
|
||||
b0_16 = _mm_setzero_si128();
|
||||
c0_16 = _mm_setzero_si128();
|
||||
d0_16 = _mm_setzero_si128();
|
||||
|
||||
Reference in New Issue
Block a user