Implemented reference unreorder bf16 function

Description:

Implemented a c reference for
aocl_gemm_unreorder_bf16bf16f32of32 function

The implementation working for row major and
column major yet to be enabled.

AMD-Internal: [ SWLCSG-3279 ]

Change-Id: Ibcce4180bb897a40252140012d8d6886c38cb77a
This commit is contained in:
Nallani Bhaskar
2025-02-06 10:17:40 +00:00
parent ef04388a44
commit 0acb5eb9a4
6 changed files with 1001 additions and 33 deletions

View File

@@ -41,7 +41,7 @@
void unpackb_nr48_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb
)
@@ -73,8 +73,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
b0 = _mm512_permutex2var_epi16( b0, selector_odd, a0 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), b0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), b0 );
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
d0 = _mm512_setzero_si512();
@@ -82,8 +82,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
kr_new += 3;
}
@@ -96,18 +96,18 @@ void unpackb_nr48_bf16bf16f32of32_row_major
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a01 );
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
c01 = _mm512_permutex2var_epi16( c0, selector_even, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
}
}
void unpackb_nr32_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb
)
@@ -138,8 +138,8 @@ void unpackb_nr32_bf16bf16f32of32_row_major
c0 = _mm512_permutex2var_epi16( c0, selector_odd, a0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), c0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), c0 );
}
if( k_partial_pieces > 0 )
@@ -150,13 +150,13 @@ void unpackb_nr32_bf16bf16f32of32_row_major
a0 = _mm512_permutex2var_epi16( c0, selector_even, a0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a0 );
}
}
void unpackb_nr16_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb
)
@@ -187,8 +187,8 @@ void unpackb_nr16_bf16bf16f32of32_row_major
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
}
if( k_partial_pieces > 0 )
{
@@ -197,13 +197,13 @@ void unpackb_nr16_bf16bf16f32of32_row_major
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
}
}
void unpackb_nrlt16_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t KC,
dim_t ldb,
dim_t n0_partial_rem
@@ -237,8 +237,8 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), store_mask, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), store_mask, c0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), store_mask, a01 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), store_mask, c0 );
}
if( k_partial_pieces > 0 )
{
@@ -247,14 +247,14 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
// Store to unpack buffer
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
}
}
void unpackb_nr64_bf16bf16f32of32_row_major
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t NC,
const dim_t KC,
dim_t ldb
@@ -304,10 +304,10 @@ void unpackb_nr64_bf16bf16f32of32_row_major
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc, b0 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc, b0 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
}
if( k_partial_pieces > 0 )
@@ -322,8 +322,8 @@ void unpackb_nr64_bf16bf16f32of32_row_major
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
// Store to unpack buffer
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
}
}
@@ -344,7 +344,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
unpackb_nr48_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
);
n0_partial_unpack = 48;
@@ -354,7 +354,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
unpackb_nr32_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
);
n0_partial_unpack = 32;
@@ -364,7 +364,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
unpackb_nr16_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
);
n0_partial_unpack = 16;
@@ -376,7 +376,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
(
( b + ( n_full_pieces_loop_limit * KC_updated ) +
( n0_partial_unpack * KC_updated ) ),
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
( unpack_b_buffer + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
n0_partial_rem
);
}
@@ -895,7 +895,7 @@ void unpackb_nr64_bf16bf16f32of32_col_major
void unpackb_nr64_bf16bf16f32of32
(
const bfloat16* b,
bfloat16* unpack_b_buffer_bf16bf16f32of32,
bfloat16* unpack_b_buffer,
const dim_t NC,
const dim_t KC,
dim_t rs_b,
@@ -904,11 +904,11 @@ void unpackb_nr64_bf16bf16f32of32
{
if( cs_b == 1 )
{
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, rs_b );
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer, NC, KC, rs_b );
}
else
{
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, cs_b );
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer, NC, KC, cs_b );
}
}
#endif // BLIS_ADDON_LPGEMM