Bug Fix in BF16 AVX2 conversion path (#236)

- In the current implementation of bf16 to f32 conversion for packed data
 we handle both GEMM and GEMV conditions in the same function separated
 with conditions.
 - But, when n = (NC+1) the function would execute GEMV conversion logic
 and write back the data inaccurately leading to accuracy issues.
 - Hence, modified the convert function and reorder functions to have
 separate conversion logic to make it cleaner and avoid confusions.
 -  Also, updated the API calls to adhere to the changes appropriately.

[AMD-Internal: CPUPL-7540]
This commit is contained in:
V, Varsha
2025-10-17 15:38:02 +05:30
committed by GitHub
parent 0ce45e3147
commit fecb1aa7a5
4 changed files with 309 additions and 274 deletions

View File

@@ -891,21 +891,18 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
{
/* For n = 1 case, a re-ordered matrix would be stored contigously
in memeory and hence need to be accessed likewise for conversion.*/
unpackb_nr64_bf16_f32
unpackb_nr64_bf16_f32_gemv
(
b, cvt_b_buffer_bf16_f32,
k, 1,
rs_b, cs_b, TRUE
b, cvt_b_buffer_bf16_f32, k
);
}
else
{
cvt_bf16_f32
// Direct call to optimized GEMV conversion (K=1, contiguous output)
cvt_bf16_f32_gemv_row_major
(
cvt_b_buffer_bf16_f32 ,
b, rs_b, cs_b,
k, 1,
rs_b_use, cs_b_use
cvt_b_buffer_bf16_f32,
b, rs_b, k
);
}
b_use = cvt_b_buffer_bf16_f32;
@@ -1084,7 +1081,7 @@ LPGEMV_AVX2(bfloat16, bfloat16, float, bf16bf16f32of32)
( jc_cur_loop_rem * kc0_updated) + ( n_sub_updated * pc )),
( b_unreorder + (nc0 * pc)),
kc0, nc0 ,
rs_b_use, cs_b_use, FALSE
rs_b_use, cs_b_use
);
b_use = b_unreorder;
}
@@ -1431,7 +1428,7 @@ LPGEMM_5LOOP_AVX2(bfloat16,bfloat16,float,bf16bf16f32of32)
( ( jc_cur_loop_rem + jc_packb_start ) *
kc0_updated ) ),( b_unreorder + jc_packb_start),
kc0, ( jc_packb_end - jc_packb_start ) ,
rs_b_use, cs_b_use, FALSE
rs_b_use, cs_b_use
);
}

View File

@@ -195,8 +195,7 @@ void unpackb_nr64_bf16_f32
const dim_t KC,
const dim_t NC,
dim_t rs_b,
dim_t cs_b,
bool is_n_one
dim_t cs_b
);
void cvt_bf16_f32(
@@ -210,4 +209,21 @@ void cvt_bf16_f32(
const dim_t cs_p
);
// Optimized GEMV conversion for true K=1 matrices with contiguous output
void cvt_bf16_f32_gemv_row_major(
float* cvt_buffer,
const bfloat16* a,
const dim_t rs_a,
const dim_t MC
);
// Optimized GEMV unpacking for true N=1 reordered matrices (contiguous storage)
void
unpackb_nr64_bf16_f32_gemv(
const bfloat16* b,
float* unpack_b_buffer,
const dim_t KC
);
#endif //BLIS_GEMM_BF16_PACKB