mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-07 06:50:09 +00:00
Fix the strange FA behavior with odd/even batch sizes (#171)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -17362,30 +17362,33 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
|
||||
int64_t work_per_slice = D*nek1*neq1;
|
||||
int ntg = 1;
|
||||
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
|
||||
else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
|
||||
else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
|
||||
if ((neq2*neq3)%(nth/ntg) == 0) {
|
||||
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d, ntg = %d, neq1/ntg = %d\n", __func__,
|
||||
// (int)D, (int)neq2, (int)neq1, (int)nek1, ntg, (int)(neq1/ntg));
|
||||
int counter = 0;
|
||||
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||
int iq1 = (ith%ntg)*neq1/ntg;
|
||||
if (!iqk_flash_attn_noalibi(k->type, v->type,
|
||||
D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
|
||||
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
|
||||
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
|
||||
(const void *)((const char *)mask->data + iq1*mask->nb[1]),
|
||||
scale, softcap,
|
||||
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
|
||||
}
|
||||
//
|
||||
// When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
|
||||
// But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
|
||||
// the number of threads processing the (iq2, iq3) matrix.
|
||||
//
|
||||
if (neq1 >= 8*nth) {
|
||||
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
|
||||
else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
|
||||
else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
|
||||
}
|
||||
int counter = 0;
|
||||
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
|
||||
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
|
||||
if (counter++ % (nth/ntg) == ith/ntg) {
|
||||
int iq1 = (ith%ntg)*neq1/ntg;
|
||||
if (!iqk_flash_attn_noalibi(k->type, v->type,
|
||||
D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
|
||||
(const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
|
||||
(const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
|
||||
(const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
|
||||
(const void *)((const char *)mask->data + iq1*mask->nb[1]),
|
||||
scale, softcap,
|
||||
(float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
IQK_Flash_Attn_NotAvailable:;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user