mirror of
https://github.com/amd/blis.git
synced 2026-04-20 15:48:50 +00:00
Bug Fix in BF16 AVX2 conversion path (#236)
- In the current implementation of bf16 to f32 conversion for packed data we handle both GEMM and GEMV conditions in the same function separated with conditions. - But, when n = (NC+1) the function would execute the GEMV conversion logic and write back the data inaccurately, leading to accuracy issues. - Hence, modified the convert function and reorder functions to have separate conversion logic to make the code cleaner and avoid confusion. - Also, updated the API calls to adhere to the changes appropriately. [AMD-Internal: CPUPL-7540]
This commit is contained in:
@@ -891,21 +891,18 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
{
|
||||
/* For n = 1 case, a re-ordered matrix would be stored contiguously
|
||||
in memory and hence needs to be accessed likewise for conversion.*/
|
||||
unpackb_nr64_bf16_f32
|
||||
unpackb_nr64_bf16_f32_gemv
|
||||
(
|
||||
b, cvt_b_buffer_bf16_f32,
|
||||
k, 1,
|
||||
rs_b, cs_b, TRUE
|
||||
b, cvt_b_buffer_bf16_f32, k
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
cvt_bf16_f32
|
||||
// Direct call to optimized GEMV conversion (K=1, contiguous output)
|
||||
cvt_bf16_f32_gemv_row_major
|
||||
(
|
||||
cvt_b_buffer_bf16_f32 ,
|
||||
b, rs_b, cs_b,
|
||||
k, 1,
|
||||
rs_b_use, cs_b_use
|
||||
cvt_b_buffer_bf16_f32,
|
||||
b, rs_b, k
|
||||
);
|
||||
}
|
||||
b_use = cvt_b_buffer_bf16_f32;
|
||||
@@ -1084,7 +1081,7 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
|
||||
( jc_cur_loop_rem * kc0_updated) + ( n_sub_updated * pc )),
|
||||
( b_unreorder + (nc0 * pc)),
|
||||
kc0, nc0 ,
|
||||
rs_b_use, cs_b_use, FALSE
|
||||
rs_b_use, cs_b_use
|
||||
);
|
||||
b_use = b_unreorder;
|
||||
}
|
||||
@@ -1431,7 +1428,7 @@ LPGEMM_5LOOP_AVX2(bfloat16,bfloat16,float,bf16bf16f32of32)
|
||||
( ( jc_cur_loop_rem + jc_packb_start ) *
|
||||
kc0_updated ) ),( b_unreorder + jc_packb_start),
|
||||
kc0, ( jc_packb_end - jc_packb_start ) ,
|
||||
rs_b_use, cs_b_use, FALSE
|
||||
rs_b_use, cs_b_use
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -195,8 +195,7 @@ void unpackb_nr64_bf16_f32
|
||||
const dim_t KC,
|
||||
const dim_t NC,
|
||||
dim_t rs_b,
|
||||
dim_t cs_b,
|
||||
bool is_n_one
|
||||
dim_t cs_b
|
||||
);
|
||||
|
||||
void cvt_bf16_f32(
|
||||
@@ -210,4 +209,21 @@ void cvt_bf16_f32(
|
||||
const dim_t cs_p
|
||||
);
|
||||
|
||||
// Optimized GEMV conversion for true K=1 matrices with contiguous output
|
||||
void cvt_bf16_f32_gemv_row_major(
|
||||
float* cvt_buffer,
|
||||
const bfloat16* a,
|
||||
const dim_t rs_a,
|
||||
const dim_t MC
|
||||
);
|
||||
|
||||
// Optimized GEMV unpacking for true N=1 reordered matrices (contiguous storage)
|
||||
void
|
||||
unpackb_nr64_bf16_f32_gemv(
|
||||
const bfloat16* b,
|
||||
float* unpack_b_buffer,
|
||||
const dim_t KC
|
||||
);
|
||||
|
||||
|
||||
#endif //BLIS_GEMM_BF16_PACKB
|
||||
|
||||
Reference in New Issue
Block a user