Bug Fix in BF16 AVX2 conversion path (#236) (#241)

- In the current implementation of bf16 to f32 conversion for packed data
 we handle both GEMM and GEMV conditions in the same function separated
 with conditions.
 - But, when n = (NC+1) the function would execute GEMV conversion logic
 and write back the data inaccurately leading to accuracy issues.
 - Hence, modified the convert function and reorder functions to have
 separate conversion logic to make it cleaner and avoid confusion.
 -  Also, updated the API calls to adhere to the changes appropriately.

[AMD-Internal: CPUPL-7540]
This commit is contained in:
V, Varsha
2025-10-24 16:39:41 +05:30
committed by GitHub
parent 64d8d06aad
commit 4178e3c5ff
4 changed files with 309 additions and 274 deletions

View File

@@ -891,21 +891,18 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
{
/* For n = 1 case, a re-ordered matrix would be stored contiguously
 in memory and hence needs to be accessed likewise for conversion.*/
unpackb_nr64_bf16_f32
unpackb_nr64_bf16_f32_gemv
(
b, cvt_b_buffer_bf16_f32,
k, 1,
rs_b, cs_b, TRUE
b, cvt_b_buffer_bf16_f32, k
);
}
else
{
cvt_bf16_f32
// Direct call to optimized GEMV conversion (K=1, contiguous output)
cvt_bf16_f32_gemv_row_major
(
cvt_b_buffer_bf16_f32 ,
b, rs_b, cs_b,
k, 1,
rs_b_use, cs_b_use
cvt_b_buffer_bf16_f32,
b, rs_b, k
);
}
b_use = cvt_b_buffer_bf16_f32;
@@ -1084,7 +1081,7 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
( jc_cur_loop_rem * kc0_updated) + ( n_sub_updated * pc )),
( b_unreorder + (nc0 * pc)),
kc0, nc0 ,
rs_b_use, cs_b_use, FALSE
rs_b_use, cs_b_use
);
b_use = b_unreorder;
}
@@ -1431,7 +1428,7 @@ LPGEMM_5LOOP_AVX2(bfloat16,bfloat16,float,bf16bf16f32of32)
( ( jc_cur_loop_rem + jc_packb_start ) *
kc0_updated ) ),( b_unreorder + jc_packb_start),
kc0, ( jc_packb_end - jc_packb_start ) ,
rs_b_use, cs_b_use, FALSE
rs_b_use, cs_b_use
);
}

View File

@@ -195,8 +195,7 @@ void unpackb_nr64_bf16_f32
const dim_t KC,
const dim_t NC,
dim_t rs_b,
dim_t cs_b,
bool is_n_one
dim_t cs_b
);
void cvt_bf16_f32(
@@ -210,4 +209,21 @@ void cvt_bf16_f32(
const dim_t cs_p
);
// Optimized GEMV conversion for true K=1 matrices with contiguous output
void cvt_bf16_f32_gemv_row_major(
float* cvt_buffer,
const bfloat16* a,
const dim_t rs_a,
const dim_t MC
);
// Optimized GEMV unpacking for true N=1 reordered matrices (contiguous storage)
void
unpackb_nr64_bf16_f32_gemv(
const bfloat16* b,
float* unpack_b_buffer,
const dim_t KC
);
#endif //BLIS_GEMM_BF16_PACKB

View File

@@ -53,6 +53,47 @@
#define LOAD_AND_CONVERT_BF16_F32(reg, ic ) \
reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
(const __m128i*)( a + ( ic * rs_a ) + ( kr * cs_a ) ) ) );
// GEMV conversion for true K=1 matrices(for B matrix) with contiguous output
// This is a fast path for matrix-vector operations where the output is stored
// contiguously
void
cvt_bf16_f32_gemv_row_major
(
    float* cvt_buffer,
    const bfloat16* a,
    const dim_t rs_a,
    const dim_t MC
)
{
    /* GEMV fast path for true K=1 matrices (used for the B operand):
       gather MC bf16 values spaced rs_a apart, widen them to f32, and
       write the results contiguously into cvt_buffer.
       If the A-matrix is col-major, MC = k due to the operand swap;
       if the B-matrix is row-major, MC = k. */
    __m256 vec_f32;
    __m256i store_mask;
    dim_t idx = 0;

    /* Full groups of 8 elements: stage through a scratch buffer so the
       strided bf16 values can be loaded with a single 128-bit load. */
    while ( ( idx + 8 ) < MC )
    {
        bfloat16 gather[8] = { 0 };
        for ( int lane = 0; lane < 8; lane++ )
        {
            gather[lane] = *( a + ( idx + lane ) * rs_a );
        }
        vec_f32 = CVT_BF16_F32_SHIFT_AVX2(
                    (__m128i)_mm_loadu_si128( (const __m128i*)gather ) );
        _mm256_storeu_ps( cvt_buffer + idx, vec_f32 );
        idx += 8;
    }

    /* Trailing group (<= 8 elements) finished with a masked store. */
    if ( idx < MC )
    {
        bfloat16 gather[8] = { 0 };
        for ( int lane = 0; lane < ( MC - idx ); lane++ )
        {
            gather[lane] = *( a + ( idx + lane ) * rs_a );
        }
        vec_f32 = CVT_BF16_F32_SHIFT_AVX2(
                    (__m128i)_mm_loadu_si128( (const __m128i*)gather ) );
        GET_STORE_MASK( ( MC - idx ), store_mask );
        _mm256_maskstore_ps( cvt_buffer + idx, store_mask, vec_f32 );
    }
}
void cvt_bf16_f32_row_major
(
@@ -66,251 +107,225 @@ void cvt_bf16_f32_row_major
const dim_t cs_p
)
{
if( KC != 1)
dim_t MR = 16;
dim_t k_left = KC % 8;
__m256i store_mask;
__m256 a_reg[16];
dim_t ic = 0, kr = 0;
for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR )
{
dim_t MR = 16;
dim_t k_left = KC % 8;
__m256i store_mask;
__m256 a_reg[16];
dim_t ic = 0, kr = 0;
for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR )
for( kr = 0; ( kr + 8 - 1) < KC; kr += 8 )
{
for( kr = 0; ( kr + 8 - 1) < KC; kr += 8 )
{
/*Load 8 BF16 elements from 16 rows, and convert them to F32 elements*/
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[8], ( ic + 8 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[9], ( ic + 9 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[10], ( ic + 10 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[11], ( ic + 11 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[12], ( ic + 12 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[13], ( ic + 13 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[14], ( ic + 14 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[15], ( ic + 15 ) );
/*Load 8 BF16 elements from 16 rows, and convert them to F32 elements*/
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[8], ( ic + 8 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[9], ( ic + 9 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[10], ( ic + 10 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[11], ( ic + 11 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[12], ( ic + 12 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[13], ( ic + 13 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[14], ( ic + 14 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[15], ( ic + 15 ) );
/*Store 8 F32 elements each in 16 rows */
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr , a_reg[8] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr , a_reg[9] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr , a_reg[10] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr , a_reg[11] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr , a_reg[12] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr , a_reg[13] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr , a_reg[14] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr , a_reg[15] );
}
if( k_left > 0)
{
/*Using a data_feeder function to load < 8 elements and convert
to f32 elements*/
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[8],k_left,( ic + 8 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[9],k_left,( ic + 9 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[10],k_left,( ic + 10 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[11],k_left,( ic + 11 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[12],k_left,( ic + 12 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[13],k_left,( ic + 13 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[14],k_left,( ic + 14 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[15],k_left,( ic + 15 ), kr);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask, a_reg[2] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask, a_reg[3] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask, a_reg[4] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask, a_reg[5] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask, a_reg[6] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask, a_reg[7] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr, store_mask, a_reg[8] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr, store_mask, a_reg[9] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr, store_mask, a_reg[10] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr, store_mask, a_reg[11] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr, store_mask, a_reg[12] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr, store_mask, a_reg[13] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr, store_mask, a_reg[14] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr, store_mask, a_reg[15] );
}
/*Store 8 F32 elements each in 16 rows */
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr , a_reg[8] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr , a_reg[9] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr , a_reg[10] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr , a_reg[11] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr , a_reg[12] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr , a_reg[13] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr , a_reg[14] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr , a_reg[15] );
}
for( ; ( ic + 8 - 1 ) < MC; ic += 8 )
if( k_left > 0)
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
/*Using a data_feeder function to load < 8 elements and convert
to f32 elements*/
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[8],k_left,( ic + 8 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[9],k_left,( ic + 9 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[10],k_left,( ic + 10 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[11],k_left,( ic + 11 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[12],k_left,( ic + 12 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[13],k_left,( ic + 13 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[14],k_left,( ic + 14 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[15],k_left,( ic + 15 ), kr);
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
}
if(k_left)
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
GET_STORE_MASK(k_left, store_mask);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask , a_reg[4] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask , a_reg[5] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask , a_reg[6] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask , a_reg[7] );
}
}
for( ; ( ic + 4 - 1 ) < MC; ic += 4 )
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
}
if( k_left > 0 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
}
}
for( ; ( ic + 2 - 1 ) < MC; ic += 2 )
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
}
if( k_left > 0 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
}
}
for( ; ( ic ) < MC; ic += 1 )
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
}
for( ; ( kr + 4 - 1 ) < KC; kr += 4 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],4,( ic + 0 ), kr);
GET_STORE_MASK(4, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
}
for( ; ( kr + 2 - 1 ) < KC; kr += 2 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],2,( ic + 0 ), kr);
GET_STORE_MASK(2, store_mask);
_mm256_maskstore_ps ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
}
for( ; ( kr ) < KC; kr += 1 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],1,( ic + 0 ), kr);
GET_STORE_MASK(2, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
}
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask, a_reg[2] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask, a_reg[3] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask, a_reg[4] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask, a_reg[5] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask, a_reg[6] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask, a_reg[7] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr, store_mask, a_reg[8] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr, store_mask, a_reg[9] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr, store_mask, a_reg[10] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr, store_mask, a_reg[11] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr, store_mask, a_reg[12] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr, store_mask, a_reg[13] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr, store_mask, a_reg[14] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr, store_mask, a_reg[15] );
}
}
else
for( ; ( ic + 8 - 1 ) < MC; ic += 8 )
{
/* If A-matrix is col-major MC = k due to swapping
of matrix and if B-matrix is row-major MC = k .*/
__m256 a_reg;
dim_t m0;
__m256i store_mask;
for( m0 = 0; ( m0 + 8 ) < MC; m0 += 8 )
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
bfloat16 buff[8] = {0};
for( int i = 0; i < 8; i++ ) buff[i] = (*( a + (m0 + i) * rs_a) );
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
(const __m128i*)( buff ) ) );
_mm256_storeu_ps( ( cvt_buffer + m0 ), a_reg );
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
}
if(k_left)
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask , a_reg[4] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask , a_reg[5] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask , a_reg[6] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask , a_reg[7] );
}
}
for( ; ( ic + 4 - 1 ) < MC; ic += 4 )
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
}
if( k_left > 0 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
}
}
for( ; ( ic + 2 - 1 ) < MC; ic += 2 )
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
}
if( k_left > 0 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
}
}
for( ; ( ic ) < MC; ic += 1 )
{
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
{
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
}
for( ; ( kr + 4 - 1 ) < KC; kr += 4 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],4,( ic + 0 ), kr);
GET_STORE_MASK(4, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
}
for( ; ( kr + 2 - 1 ) < KC; kr += 2 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],2,( ic + 0 ), kr);
GET_STORE_MASK(2, store_mask);
_mm256_maskstore_ps ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
}
for( ; ( kr ) < KC; kr += 1 )
{
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],1,( ic + 0 ), kr);
GET_STORE_MASK(2, store_mask);
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
}
bfloat16 buff[8] = {0};
for( int i = 0; i < (MC - m0); i++ ) buff[i] = (*( a + (m0 + i)*rs_a ) );
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
(const __m128i*)( buff ) ) );
GET_STORE_MASK((MC - m0), store_mask);
_mm256_maskstore_ps( ( cvt_buffer + m0 ), store_mask, a_reg );
}
}

View File

@@ -584,6 +584,39 @@ void unpackb_nr64_bf16_f32_row_major
}
void
unpackb_nr64_bf16_f32_gemv(const bfloat16* b,
                           float* unpack_b_buffer,
                           const dim_t KC)
{
    /* GEMV fast path for true N=1 reordered matrices.  With N = 1 the
       reordered B is already stored contiguously, so no unpacking is
       required -- only a bf16 -> f32 widening copy. */
    __m256 vec_f32;
    __m256i store_mask;
    dim_t kk = 0;

    /* Full groups of 8 elements: the bf16 data is contiguous, so 16
       bytes can be loaded directly from the source. */
    while ( ( kk + 8 ) < KC )
    {
        vec_f32 = CVT_BF16_F32_SHIFT_AVX2(
                    (__m128i)_mm_loadu_si128( (const __m128i*)( b + kk ) ) );
        _mm256_storeu_ps( unpack_b_buffer + kk, vec_f32 );
        kk += 8;
    }

    /* Trailing group (<= 8 elements): stage through a zeroed scratch
       buffer to avoid reading past the source, then mask the store. */
    dim_t tail = KC - kk;
    if ( tail > 0 )
    {
        bfloat16 scratch[8] = { 0 };
        for ( int lane = 0; lane < tail; lane++ )
        {
            scratch[lane] = *( b + kk + lane );
        }
        vec_f32 = CVT_BF16_F32_SHIFT_AVX2(
                    (__m128i)_mm_loadu_si128( (const __m128i*)scratch ) );
        GET_STORE_MASK( tail, store_mask );
        _mm256_maskstore_ps( unpack_b_buffer + kk, store_mask, vec_f32 );
    }
}
void unpackb_nr64_bf16_f32
(
const bfloat16* b,
@@ -591,36 +624,10 @@ void unpackb_nr64_bf16_f32
const dim_t KC,
const dim_t NC,
dim_t rs_b,
dim_t cs_b,
bool is_n_one
dim_t cs_b
)
{
if( is_n_one == TRUE )
{
__m256 a_reg;
dim_t k0 = 0;
__m256i store_mask;
for( ; ( k0 + 8 ) < KC; k0 += 8 )
{
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
(const __m128i*)(( b + k0 ) ) ) );
_mm256_storeu_ps( ( unpack_b_buffer + k0 ), a_reg );
}
dim_t k_left = (KC - k0);
if( k_left > 0 )
{
bfloat16 buff[8] = {0};
for( int i = 0; i < k_left; i++ )buff[i] = ( *( b + (k0 + i) ) );
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
(const __m128i*)( buff ) ) );
GET_STORE_MASK(k_left, store_mask);
_mm256_maskstore_ps( ( unpack_b_buffer + k0 ), store_mask, a_reg );
}
}
else if( cs_b == 1 )
if( cs_b == 1 )
{
unpackb_nr64_bf16_f32_row_major( b, unpack_b_buffer, NC, KC, rs_b );
}