mirror of
https://github.com/amd/blis.git
synced 2026-04-19 23:28:52 +00:00
Bug Fix in BF16 AVX2 conversion path (#236)
- In the current implementation of bf16 to f32 conversion for packed data we handle both GEMM and GEMV conditions in the same function separated with conditions. - But, when n = (NC+1) the function would execute GEMV conversion logic and write back the data inaccurately leading to accuracy issues. - Hence, modified the convert function and reorder functions to have separate conversion logic to make it cleaner and avoid confusions. - Also, updated the API calls to adhere to the changes appropriately. [AMD-Internal: CPUPL-7540]
This commit is contained in:
@@ -891,21 +891,18 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
{
|
||||
/* For n = 1 case, a re-ordered matrix would be stored contigously
|
||||
in memeory and hence need to be accessed likewise for conversion.*/
|
||||
unpackb_nr64_bf16_f32
|
||||
unpackb_nr64_bf16_f32_gemv
|
||||
(
|
||||
b, cvt_b_buffer_bf16_f32,
|
||||
k, 1,
|
||||
rs_b, cs_b, TRUE
|
||||
b, cvt_b_buffer_bf16_f32, k
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
cvt_bf16_f32
|
||||
// Direct call to optimized GEMV conversion (K=1, contiguous output)
|
||||
cvt_bf16_f32_gemv_row_major
|
||||
(
|
||||
cvt_b_buffer_bf16_f32 ,
|
||||
b, rs_b, cs_b,
|
||||
k, 1,
|
||||
rs_b_use, cs_b_use
|
||||
cvt_b_buffer_bf16_f32,
|
||||
b, rs_b, k
|
||||
);
|
||||
}
|
||||
b_use = cvt_b_buffer_bf16_f32;
|
||||
@@ -1084,7 +1081,7 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
( jc_cur_loop_rem * kc0_updated) + ( n_sub_updated * pc )),
|
||||
( b_unreorder + (nc0 * pc)),
|
||||
kc0, nc0 ,
|
||||
rs_b_use, cs_b_use, FALSE
|
||||
rs_b_use, cs_b_use
|
||||
);
|
||||
b_use = b_unreorder;
|
||||
}
|
||||
@@ -1431,7 +1428,7 @@ LPGEMM_5LOOP_AVX2(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
( ( jc_cur_loop_rem + jc_packb_start ) *
|
||||
kc0_updated ) ),( b_unreorder + jc_packb_start),
|
||||
kc0, ( jc_packb_end - jc_packb_start ) ,
|
||||
rs_b_use, cs_b_use, FALSE
|
||||
rs_b_use, cs_b_use
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -195,8 +195,7 @@ void unpackb_nr64_bf16_f32
|
||||
const dim_t KC,
|
||||
const dim_t NC,
|
||||
dim_t rs_b,
|
||||
dim_t cs_b,
|
||||
bool is_n_one
|
||||
dim_t cs_b
|
||||
);
|
||||
|
||||
void cvt_bf16_f32(
|
||||
@@ -210,4 +209,21 @@ void cvt_bf16_f32(
|
||||
const dim_t cs_p
|
||||
);
|
||||
|
||||
// Optimized GEMV conversion for true K=1 matrices with contiguous output
|
||||
void cvt_bf16_f32_gemv_row_major(
|
||||
float* cvt_buffer,
|
||||
const bfloat16* a,
|
||||
const dim_t rs_a,
|
||||
const dim_t MC
|
||||
);
|
||||
|
||||
// Optimized GEMV unpacking for true N=1 reordered matrices (contiguous storage)
|
||||
void
|
||||
unpackb_nr64_bf16_f32_gemv(
|
||||
const bfloat16* b,
|
||||
float* unpack_b_buffer,
|
||||
const dim_t KC
|
||||
);
|
||||
|
||||
|
||||
#endif //BLIS_GEMM_BF16_PACKB
|
||||
|
||||
@@ -53,6 +53,47 @@
|
||||
#define LOAD_AND_CONVERT_BF16_F32(reg, ic ) \
|
||||
reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
|
||||
(const __m128i*)( a + ( ic * rs_a ) + ( kr * cs_a ) ) ) );
|
||||
// GEMV conversion for true K=1 matrices(for B matrix) with contiguous output
|
||||
// This is a fast path for matrix-vector operations where the output is stored
|
||||
// contiguously
|
||||
void
|
||||
cvt_bf16_f32_gemv_row_major
|
||||
(
|
||||
float* cvt_buffer,
|
||||
const bfloat16* a,
|
||||
const dim_t rs_a,
|
||||
const dim_t MC
|
||||
)
|
||||
{
|
||||
/* For true GEMV (K=1 matrices with contiguous output):
|
||||
If A-matrix is col-major MC = k due to swapping,
|
||||
if B-matrix is row-major MC = k.
|
||||
This function stores converted values contiguously in memory. */
|
||||
__m256 a_reg;
|
||||
dim_t m0;
|
||||
__m256i store_mask;
|
||||
|
||||
// Process 8 elements at a time
|
||||
for (m0 = 0; (m0 + 8) < MC; m0 += 8) {
|
||||
bfloat16 buff[8] = { 0 };
|
||||
for (int i = 0; i < 8; i++)
|
||||
buff[i] = (*(a + (m0 + i) * rs_a));
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2(
|
||||
(__m128i)_mm_loadu_si128((const __m128i*)(buff)));
|
||||
_mm256_storeu_ps((cvt_buffer + m0), a_reg);
|
||||
}
|
||||
|
||||
// Handle remaining elements (< 8)
|
||||
if (m0 < MC) {
|
||||
bfloat16 buff[8] = { 0 };
|
||||
for (int i = 0; i < (MC - m0); i++)
|
||||
buff[i] = (*(a + (m0 + i) * rs_a));
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2(
|
||||
(__m128i)_mm_loadu_si128((const __m128i*)(buff)));
|
||||
GET_STORE_MASK((MC - m0), store_mask);
|
||||
_mm256_maskstore_ps((cvt_buffer + m0), store_mask, a_reg);
|
||||
}
|
||||
}
|
||||
|
||||
void cvt_bf16_f32_row_major
|
||||
(
|
||||
@@ -66,251 +107,225 @@ void cvt_bf16_f32_row_major
|
||||
const dim_t cs_p
|
||||
)
|
||||
{
|
||||
if( KC != 1)
|
||||
dim_t MR = 16;
|
||||
dim_t k_left = KC % 8;
|
||||
|
||||
__m256i store_mask;
|
||||
|
||||
__m256 a_reg[16];
|
||||
|
||||
dim_t ic = 0, kr = 0;
|
||||
|
||||
for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR )
|
||||
{
|
||||
dim_t MR = 16;
|
||||
dim_t k_left = KC % 8;
|
||||
|
||||
__m256i store_mask;
|
||||
|
||||
__m256 a_reg[16];
|
||||
|
||||
dim_t ic = 0, kr = 0;
|
||||
|
||||
for( ic = 0; ( ic + MR - 1 ) < MC; ic += MR )
|
||||
for( kr = 0; ( kr + 8 - 1) < KC; kr += 8 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1) < KC; kr += 8 )
|
||||
{
|
||||
/*Load 8 BF16 elements from 16 rows, and convert them to F32 elements*/
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[8], ( ic + 8 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[9], ( ic + 9 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[10], ( ic + 10 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[11], ( ic + 11 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[12], ( ic + 12 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[13], ( ic + 13 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[14], ( ic + 14 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[15], ( ic + 15 ) );
|
||||
/*Load 8 BF16 elements from 16 rows, and convert them to F32 elements*/
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[8], ( ic + 8 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[9], ( ic + 9 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[10], ( ic + 10 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[11], ( ic + 11 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[12], ( ic + 12 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[13], ( ic + 13 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[14], ( ic + 14 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[15], ( ic + 15 ) );
|
||||
|
||||
/*Store 8 F32 elements each in 16 rows */
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr , a_reg[8] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr , a_reg[9] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr , a_reg[10] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr , a_reg[11] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr , a_reg[12] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr , a_reg[13] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr , a_reg[14] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr , a_reg[15] );
|
||||
}
|
||||
if( k_left > 0)
|
||||
{
|
||||
/*Using a data_feeder function to load < 8 elemnts and convert
|
||||
to f32 elements*/
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[8],k_left,( ic + 8 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[9],k_left,( ic + 9 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[10],k_left,( ic + 10 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[11],k_left,( ic + 11 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[12],k_left,( ic + 12 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[13],k_left,( ic + 13 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[14],k_left,( ic + 14 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[15],k_left,( ic + 15 ), kr);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask, a_reg[2] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask, a_reg[3] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask, a_reg[4] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask, a_reg[5] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask, a_reg[6] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask, a_reg[7] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr, store_mask, a_reg[8] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr, store_mask, a_reg[9] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr, store_mask, a_reg[10] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr, store_mask, a_reg[11] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr, store_mask, a_reg[12] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr, store_mask, a_reg[13] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr, store_mask, a_reg[14] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr, store_mask, a_reg[15] );
|
||||
}
|
||||
/*Store 8 F32 elements each in 16 rows */
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr , a_reg[8] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr , a_reg[9] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr , a_reg[10] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr , a_reg[11] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr , a_reg[12] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr , a_reg[13] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr , a_reg[14] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr , a_reg[15] );
|
||||
}
|
||||
for( ; ( ic + 8 - 1 ) < MC; ic += 8 )
|
||||
if( k_left > 0)
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
|
||||
/*Using a data_feeder function to load < 8 elemnts and convert
|
||||
to f32 elements*/
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[8],k_left,( ic + 8 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[9],k_left,( ic + 9 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[10],k_left,( ic + 10 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[11],k_left,( ic + 11 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[12],k_left,( ic + 12 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[13],k_left,( ic + 13 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[14],k_left,( ic + 14 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[15],k_left,( ic + 15 ), kr);
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
|
||||
}
|
||||
if(k_left)
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask , a_reg[4] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask , a_reg[5] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask , a_reg[6] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask , a_reg[7] );
|
||||
}
|
||||
}
|
||||
for( ; ( ic + 4 - 1 ) < MC; ic += 4 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
|
||||
}
|
||||
|
||||
if( k_left > 0 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
|
||||
}
|
||||
}
|
||||
for( ; ( ic + 2 - 1 ) < MC; ic += 2 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
}
|
||||
|
||||
if( k_left > 0 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
}
|
||||
}
|
||||
for( ; ( ic ) < MC; ic += 1 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
}
|
||||
for( ; ( kr + 4 - 1 ) < KC; kr += 4 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],4,( ic + 0 ), kr);
|
||||
|
||||
GET_STORE_MASK(4, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
|
||||
}
|
||||
for( ; ( kr + 2 - 1 ) < KC; kr += 2 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],2,( ic + 0 ), kr);
|
||||
|
||||
GET_STORE_MASK(2, store_mask);
|
||||
|
||||
_mm256_maskstore_ps ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
|
||||
}
|
||||
for( ; ( kr ) < KC; kr += 1 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],1,( ic + 0 ), kr);
|
||||
|
||||
GET_STORE_MASK(2, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
|
||||
}
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask, a_reg[2] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask, a_reg[3] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask, a_reg[4] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask, a_reg[5] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask, a_reg[6] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask, a_reg[7] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 8 ) * rs_p ) + kr, store_mask, a_reg[8] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 9 ) * rs_p ) + kr, store_mask, a_reg[9] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 10 ) * rs_p ) + kr, store_mask, a_reg[10] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 11 ) * rs_p ) + kr, store_mask, a_reg[11] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 12 ) * rs_p ) + kr, store_mask, a_reg[12] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 13 ) * rs_p ) + kr, store_mask, a_reg[13] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 14 ) * rs_p ) + kr, store_mask, a_reg[14] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 15 ) * rs_p ) + kr, store_mask, a_reg[15] );
|
||||
}
|
||||
}
|
||||
else
|
||||
for( ; ( ic + 8 - 1 ) < MC; ic += 8 )
|
||||
{
|
||||
/* If A-matrix is col-major MC = k due to swapping
|
||||
of matrix and if B-matrix is row-major MC = k .*/
|
||||
__m256 a_reg;
|
||||
dim_t m0;
|
||||
__m256i store_mask;
|
||||
|
||||
for( m0 = 0; ( m0 + 8 ) < MC; m0 += 8 )
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
bfloat16 buff[8] = {0};
|
||||
for( int i = 0; i < 8; i++ ) buff[i] = (*( a + (m0 + i) * rs_a) );
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
|
||||
(const __m128i*)( buff ) ) );
|
||||
_mm256_storeu_ps( ( cvt_buffer + m0 ), a_reg );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[4], ( ic + 4 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[5], ( ic + 5 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[6], ( ic + 6 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[7], ( ic + 7 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr , a_reg[4] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr , a_reg[5] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr , a_reg[6] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr , a_reg[7] );
|
||||
}
|
||||
if(k_left)
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[4],k_left,( ic + 4 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[5],k_left,( ic + 5 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[6],k_left,( ic + 6 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[7],k_left,( ic + 7 ), kr);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 4 ) * rs_p ) + kr, store_mask , a_reg[4] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 5 ) * rs_p ) + kr, store_mask , a_reg[5] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 6 ) * rs_p ) + kr, store_mask , a_reg[6] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 7 ) * rs_p ) + kr, store_mask , a_reg[7] );
|
||||
}
|
||||
}
|
||||
for( ; ( ic + 4 - 1 ) < MC; ic += 4 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[2], ( ic + 2 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[3], ( ic + 3 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr , a_reg[2] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr , a_reg[3] );
|
||||
}
|
||||
|
||||
if( k_left > 0 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[2],k_left,( ic + 2 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[3],k_left,( ic + 3 ), kr);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 2 ) * rs_p ) + kr, store_mask , a_reg[2] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 3 ) * rs_p ) + kr, store_mask , a_reg[3] );
|
||||
}
|
||||
}
|
||||
for( ; ( ic + 2 - 1 ) < MC; ic += 2 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[1], ( ic + 1 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr , a_reg[1] );
|
||||
}
|
||||
|
||||
if( k_left > 0 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],k_left,( ic + 0 ), kr);
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[1],k_left,( ic + 1 ), kr);
|
||||
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr, store_mask, a_reg[0] );
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 1 ) * rs_p ) + kr, store_mask, a_reg[1] );
|
||||
}
|
||||
}
|
||||
for( ; ( ic ) < MC; ic += 1 )
|
||||
{
|
||||
for( kr = 0; ( kr + 8 - 1 ) < KC; kr += 8 )
|
||||
{
|
||||
LOAD_AND_CONVERT_BF16_F32(a_reg[0], ( ic + 0 ) );
|
||||
|
||||
_mm256_storeu_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , a_reg[0] );
|
||||
}
|
||||
for( ; ( kr + 4 - 1 ) < KC; kr += 4 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],4,( ic + 0 ), kr);
|
||||
|
||||
GET_STORE_MASK(4, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
|
||||
}
|
||||
for( ; ( kr + 2 - 1 ) < KC; kr += 2 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],2,( ic + 0 ), kr);
|
||||
|
||||
GET_STORE_MASK(2, store_mask);
|
||||
|
||||
_mm256_maskstore_ps ( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
|
||||
}
|
||||
for( ; ( kr ) < KC; kr += 1 )
|
||||
{
|
||||
CVT_BF16_F32_SHIFT_AVX2_lt8(a_reg[0],1,( ic + 0 ), kr);
|
||||
|
||||
GET_STORE_MASK(2, store_mask);
|
||||
|
||||
_mm256_maskstore_ps( cvt_buffer + ( ( ic + 0 ) * rs_p ) + kr , store_mask, a_reg[0] );
|
||||
}
|
||||
bfloat16 buff[8] = {0};
|
||||
for( int i = 0; i < (MC - m0); i++ ) buff[i] = (*( a + (m0 + i)*rs_a ) );
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
|
||||
(const __m128i*)( buff ) ) );
|
||||
GET_STORE_MASK((MC - m0), store_mask);
|
||||
_mm256_maskstore_ps( ( cvt_buffer + m0 ), store_mask, a_reg );
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -584,6 +584,39 @@ void unpackb_nr64_bf16_f32_row_major
|
||||
|
||||
}
|
||||
|
||||
void
|
||||
unpackb_nr64_bf16_f32_gemv(const bfloat16* b,
|
||||
float* unpack_b_buffer,
|
||||
const dim_t KC)
|
||||
{
|
||||
/* For true GEMV (N=1 reordered matrices): contiguous storage
|
||||
In the N=1 case, the reordered matrix is stored contiguously,
|
||||
so we just need to convert bf16 to f32 without unpacking. */
|
||||
__m256 a_reg;
|
||||
dim_t k0 = 0;
|
||||
__m256i store_mask;
|
||||
|
||||
// Process 8 elements at a time
|
||||
for (; (k0 + 8) < KC; k0 += 8) {
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2(
|
||||
(__m128i)_mm_loadu_si128((const __m128i*)((b + k0))));
|
||||
_mm256_storeu_ps((unpack_b_buffer + k0), a_reg);
|
||||
}
|
||||
|
||||
// Handle remaining elements (< 8)
|
||||
dim_t k_left = (KC - k0);
|
||||
if (k_left > 0) {
|
||||
bfloat16 buff[8] = { 0 };
|
||||
for (int i = 0; i < k_left; i++)
|
||||
buff[i] = (*(b + (k0 + i)));
|
||||
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2(
|
||||
(__m128i)_mm_loadu_si128((const __m128i*)(buff)));
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
_mm256_maskstore_ps((unpack_b_buffer + k0), store_mask, a_reg);
|
||||
}
|
||||
}
|
||||
|
||||
void unpackb_nr64_bf16_f32
|
||||
(
|
||||
const bfloat16* b,
|
||||
@@ -591,36 +624,10 @@ void unpackb_nr64_bf16_f32
|
||||
const dim_t KC,
|
||||
const dim_t NC,
|
||||
dim_t rs_b,
|
||||
dim_t cs_b,
|
||||
bool is_n_one
|
||||
dim_t cs_b
|
||||
)
|
||||
{
|
||||
if( is_n_one == TRUE )
|
||||
{
|
||||
__m256 a_reg;
|
||||
dim_t k0 = 0;
|
||||
__m256i store_mask;
|
||||
|
||||
for( ; ( k0 + 8 ) < KC; k0 += 8 )
|
||||
{
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
|
||||
(const __m128i*)(( b + k0 ) ) ) );
|
||||
_mm256_storeu_ps( ( unpack_b_buffer + k0 ), a_reg );
|
||||
}
|
||||
dim_t k_left = (KC - k0);
|
||||
|
||||
if( k_left > 0 )
|
||||
{
|
||||
bfloat16 buff[8] = {0};
|
||||
for( int i = 0; i < k_left; i++ )buff[i] = ( *( b + (k0 + i) ) );
|
||||
|
||||
a_reg = CVT_BF16_F32_SHIFT_AVX2( (__m128i)_mm_loadu_si128( \
|
||||
(const __m128i*)( buff ) ) );
|
||||
GET_STORE_MASK(k_left, store_mask);
|
||||
_mm256_maskstore_ps( ( unpack_b_buffer + k0 ), store_mask, a_reg );
|
||||
}
|
||||
}
|
||||
else if( cs_b == 1 )
|
||||
if( cs_b == 1 )
|
||||
{
|
||||
unpackb_nr64_bf16_f32_row_major( b, unpack_b_buffer, NC, KC, rs_b );
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user