Fix the strange FA behavior with odd/even batch sizes

This commit is contained in:
Iwan Kawrakow
2025-01-12 16:12:05 +02:00
parent c19404bcda
commit 983e86805e

View File

@@ -17362,12 +17362,16 @@ static void ggml_compute_forward_flash_attn_ext_f16(
if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
int64_t work_per_slice = D*nek1*neq1;
int ntg = 1;
//
// When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
// But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
// the number of threads processing the (iq2, iq3) matrix.
//
if (neq1 >= 8*nth) {
if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
if ((neq2*neq3)%(nth/ntg) == 0) {
//if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d, ntg = %d, neq1/ntg = %d\n", __func__,
// (int)D, (int)neq2, (int)neq1, (int)nek1, ntg, (int)(neq1/ntg));
}
int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
@@ -17385,7 +17389,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
}
return;
}
IQK_Flash_Attn_NotAvailable:;
}