mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Implemented reference unreorder bf16 function
Description: Implemented a C reference for the aocl_gemm_unreorder_bf16bf16f32of32 function. The implementation is working for row major; column major is yet to be enabled. AMD-Internal: [ SWLCSG-3279 ] Change-Id: Ibcce4180bb897a40252140012d8d6886c38cb77a
This commit is contained in:
@@ -41,7 +41,7 @@
|
||||
void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
@@ -73,8 +73,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
|
||||
b0 = _mm512_permutex2var_epi16( b0, selector_odd, a0 );
|
||||
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), b0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), b0 );
|
||||
|
||||
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
|
||||
d0 = _mm512_setzero_si512();
|
||||
@@ -82,8 +82,8 @@ void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
|
||||
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
|
||||
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + NR1, 0xFFFF, d0 );
|
||||
|
||||
kr_new += 3;
|
||||
}
|
||||
@@ -96,18 +96,18 @@ void unpackb_nr48_bf16bf16f32of32_row_major
|
||||
|
||||
a01 = _mm512_permutex2var_epi16( b0, selector_even, a0 );
|
||||
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a01 );
|
||||
|
||||
c0 = _mm512_loadu_si512( b + ( ( kr_new + 2 ) * NR1 ) );
|
||||
c01 = _mm512_permutex2var_epi16( c0, selector_even, c0 );
|
||||
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + NR1, 0xFFFF, c01 );
|
||||
}
|
||||
}
|
||||
void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
@@ -138,8 +138,8 @@ void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
c0 = _mm512_permutex2var_epi16( c0, selector_odd, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), c0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ), a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ), c0 );
|
||||
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
@@ -150,13 +150,13 @@ void unpackb_nr32_bf16bf16f32of32_row_major
|
||||
a0 = _mm512_permutex2var_epi16( c0, selector_even, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), a0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), a0 );
|
||||
}
|
||||
}
|
||||
void unpackb_nr16_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
)
|
||||
@@ -187,8 +187,8 @@ void unpackb_nr16_bf16bf16f32of32_row_major
|
||||
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), 0xFFFF, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), 0xFFFF, c0 );
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
{
|
||||
@@ -197,13 +197,13 @@ void unpackb_nr16_bf16bf16f32of32_row_major
|
||||
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), 0xFFFF, a0 );
|
||||
}
|
||||
}
|
||||
void unpackb_nrlt16_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t KC,
|
||||
dim_t ldb,
|
||||
dim_t n0_partial_rem
|
||||
@@ -237,8 +237,8 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
|
||||
c0 = _mm512_permutex2var_epi16( a0, selector_odd, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ), store_mask, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ), store_mask, c0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 0 ) ), store_mask, a01 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( kr + 1 ) ), store_mask, c0 );
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
{
|
||||
@@ -247,14 +247,14 @@ void unpackb_nrlt16_bf16bf16f32of32_row_major
|
||||
a0 = _mm512_permutex2var_epi16( a0, selector_even, a0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
|
||||
_mm512_mask_storeu_epi16( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ), store_mask, a0 );
|
||||
}
|
||||
}
|
||||
|
||||
void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t ldb
|
||||
@@ -304,10 +304,10 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
d0 = _mm512_permutex2var_epi16( d0, selector_odd, c0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc, b0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 0 ) ) + jc + 32, c01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc, b0 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( kr + 1 ) ) + jc + 32, d0 );
|
||||
|
||||
}
|
||||
if( k_partial_pieces > 0 )
|
||||
@@ -322,8 +322,8 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
c01 = _mm512_permutex2var_epi16( d0, selector_even, c0 );
|
||||
|
||||
// Store to unpack buffer
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer_bf16bf16f32of32 + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc, a01 );
|
||||
_mm512_storeu_si512( unpack_b_buffer + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32, c01 );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,7 +344,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
unpackb_nr48_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 48;
|
||||
@@ -354,7 +354,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
unpackb_nr32_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 32;
|
||||
@@ -364,7 +364,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
unpackb_nr16_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit ), KC, ldb
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit ), KC, ldb
|
||||
);
|
||||
|
||||
n0_partial_unpack = 16;
|
||||
@@ -376,7 +376,7 @@ void unpackb_nr64_bf16bf16f32of32_row_major
|
||||
(
|
||||
( b + ( n_full_pieces_loop_limit * KC_updated ) +
|
||||
( n0_partial_unpack * KC_updated ) ),
|
||||
( unpack_b_buffer_bf16bf16f32of32 + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
|
||||
( unpack_b_buffer + n_full_pieces_loop_limit + n0_partial_unpack ), KC, ldb,
|
||||
n0_partial_rem
|
||||
);
|
||||
}
|
||||
@@ -895,7 +895,7 @@ void unpackb_nr64_bf16bf16f32of32_col_major
|
||||
void unpackb_nr64_bf16bf16f32of32
|
||||
(
|
||||
const bfloat16* b,
|
||||
bfloat16* unpack_b_buffer_bf16bf16f32of32,
|
||||
bfloat16* unpack_b_buffer,
|
||||
const dim_t NC,
|
||||
const dim_t KC,
|
||||
dim_t rs_b,
|
||||
@@ -904,11 +904,11 @@ void unpackb_nr64_bf16bf16f32of32
|
||||
{
|
||||
if( cs_b == 1 )
|
||||
{
|
||||
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, rs_b );
|
||||
unpackb_nr64_bf16bf16f32of32_row_major( b, unpack_b_buffer, NC, KC, rs_b );
|
||||
}
|
||||
else
|
||||
{
|
||||
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer_bf16bf16f32of32, NC, KC, cs_b );
|
||||
unpackb_nr64_bf16bf16f32of32_col_major( b, unpack_b_buffer, NC, KC, cs_b );
|
||||
}
|
||||
}
|
||||
#endif // BLIS_ADDON_LPGEMM
|
||||
|
||||
Reference in New Issue
Block a user