WIP: plugging into ggml_compute_forward_flash_attn_ext_f16

This commit is contained in:
Iwan Kawrakow
2024-08-25 19:08:28 +03:00
parent cfee1b68ec
commit 585aa2bee3

View File

@@ -6371,6 +6371,35 @@ void iqk_flash_helper(int nq, // number of elements in q
softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true); softmax_extended(nk, qk, qk, scale, slope, (const char *)mask, true);
} }
namespace {
inline void accumulate(int n, float * saux, float smax, float& M, float& S, __m512 * acc, const char * v, int stride_v) {
if (smax > M) {
if (M > -INFINITY) {
float ms = expf(M - smax);
auto vms = _mm512_set1_ps(ms);
for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
S *= ms;
} else {
for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps();
S = 0;
}
M = smax;
}
//auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M)));
//_mm256_storeu_ps(saux, vs_all);
auto vs_all = v_expf(_mm512_sub_ps(_mm512_loadu_ps(saux), _mm512_set1_ps(M)));
_mm512_storeu_ps(saux, vs_all);
for (int j = 0; j < n; ++j) {
S += saux[j];
auto vs = _mm512_set1_ps(saux[j]);
auto vr = (const ggml_half *)(v + stride_v*j);
for (int i = 0; i < 8; ++i) {
acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
}
}
}
}
void iqk_flash_helper_2(int nq, // number of elements in q void iqk_flash_helper_2(int nq, // number of elements in q
int nk, // number of rows in k int nk, // number of rows in k
int stride_k, // distance between rows in k (in bytes) int stride_k, // distance between rows in k (in bytes)
@@ -6387,22 +6416,91 @@ void iqk_flash_helper_2(int nq, // number of elements in q
//GGML_ASSERT(nq / 16 <= 16); //GGML_ASSERT(nq / 16 <= 16);
if (nq == 128) { if (nq == 128) {
constexpr int kNchunk = 16;
const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr; const ggml_half * mp = mask ? (const ggml_half *)mask : nullptr;
__m512 vq[8]; __m512 vq[8];
__m512 acc[8] = {}; __m512 acc[8]; // = {};
for (int i = 0; i < 8; ++i) vq[i] = _mm512_loadu_ps(q + 16*i); for (int i = 0; i < 8; ++i) vq[i] = _mm512_loadu_ps(q + 16*i);
float M = -INFINITY; float M = -INFINITY;
float S = 0; float S = 0;
float saux[kNchunk];
float smax = -INFINITY;
int last_ik = 0;
int ik = 0; int ik = 0;
//if (nk >= 8) { for (; ik < nk; ++ik) {
if (ik - last_ik == kNchunk) {
accumulate(kNchunk, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
//if (smax > M) {
// if (M > -INFINITY) {
// float ms = expf(M - smax);
// auto vms = _mm512_set1_ps(ms);
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
// S *= ms;
// }
// M = smax;
//}
//auto vs_all = v_expf(_mm256_sub_ps(_mm256_loadu_ps(saux), _mm256_set1_ps(M)));
//_mm256_storeu_ps(saux, vs_all);
//for (int j = 0; j < 8; ++j) {
// S += saux[j];
// auto vs = _mm512_set1_ps(saux[j]);
// auto vr = (const ggml_half *)((const char *)v + stride_v*(last_ik + j));
// for (int i = 0; i < 8; ++i) {
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
// }
//}
last_ik = ik;
smax = -INFINITY;
}
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
if (mv == -INFINITY) break; //continue;
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
float s = scale * _mm512_reduce_add_ps(sum) + mv;
saux[ik - last_ik] = s;
smax = std::max(smax, s);
}
if (ik > last_ik) {
accumulate(ik - last_ik, saux, smax, M, S, acc, (const char *)v + stride_v*last_ik, stride_v);
//if (smax > M) {
// if (M > -INFINITY) {
// float ms = expf(M - smax);
// auto vms = _mm512_set1_ps(ms);
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_mul_ps(vms, acc[i]);
// S *= ms;
// }
// M = smax;
//}
//for (int j = last_ik; j < ik; ++j) {
// float s = expf(saux[j - last_ik] - M);
// S += s;
// auto vs = _mm512_set1_ps(s);
// auto vr = (const ggml_half *)((const char *)v + stride_v*j);
// for (int i = 0; i < 8; ++i) {
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
// }
//}
}
if (S > 0) {
auto norm = _mm512_set1_ps(1/S);
for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
} else {
printf("Oops: S = %g. ik = %d, last_ik = %d, nk = %d, M = %g\n", S, ik, last_ik, nk, M);
GGML_ASSERT(false);
std::memset(qkv, 0, 128*sizeof(float));
}
return;
//int ik = 0;
//if (false && nk >= 8) {
// bool finished = false;
// float s8[8]; // float s8[8];
// for (int ik8 = 0; ik8 < nk/8; ++ik8) { // for (int ik8 = 0; ik8 < nk/8; ++ik8) {
// float smax = -INFINITY; // float smax = -INFINITY;
// for (int j = 0; j < 8; ++j) { // int j = 0;
// for (; j < 8; ++j) {
// const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f; // const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[8*ik+j]) : 0.0f;
// if (mv == -INFINITY) { // if (mv == -INFINITY) break;
// s8[j] = -INFINITY; continue;
// }
// const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j)); // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*(8*ik8 + j));
// auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
// for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
@@ -6410,6 +6508,7 @@ void iqk_flash_helper_2(int nq, // number of elements in q
// s8[j] = s; // s8[j] = s;
// smax = std::max(smax, s); // smax = std::max(smax, s);
// } // }
// if (smax == -INFINITY) { finished = true; break; }
// if (smax > M) { // if (smax > M) {
// if (M > -INFINITY) { // if (M > -INFINITY) {
// float ms = expf(M - smax); // float ms = expf(M - smax);
@@ -6418,75 +6517,63 @@ void iqk_flash_helper_2(int nq, // number of elements in q
// S *= ms; // S *= ms;
// } else { // } else {
// for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps(); // for (int i = 0; i < 8; ++i) acc[i] = _mm512_setzero_ps();
// S = 0;
// } // }
// M = smax; // M = smax;
// } // }
// if (smax == -INFINITY) break;
// _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M)))); // _mm256_storeu_ps(s8, v_expf(_mm256_sub_ps(_mm256_loadu_ps(s8), _mm256_set1_ps(M))));
// for (int j = 0; j < 8; ++j) { // for (int l = 0; l < j; ++l) {
// if (s8[j] <= 0.0f) continue; // if (s8[l] <= 0.0f) continue;
// const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+j)); // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*(8*ik8+l));
// auto vs = _mm512_set1_ps(s8[j]); // auto vs = _mm512_set1_ps(s8[l]);
// S += s8[j]; // S += s8[l];
// for (int i = 0; i < 8; ++i) { // for (int i = 0; i < 8; ++i) {
// acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]); // acc[i] = _mm512_fmadd_ps(vs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
// } // }
// } // }
// } // }
// ik = 8*(nk/8); // ik = finished ? nk : 8*(nk/8);
//} //}
//int last_i = -1; //int last_ik = 0;
for (; ik < nk; ++ik) { //for (; ik < nk; ++ik) {
const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f; // const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ik]) : 0.0f;
if (mv == -INFINITY) break; //continue; // if (mv == -INFINITY) break; //continue;
//last_i = ik; // const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik);
const ggml_half * kr = (const ggml_half *)((const char *)k + stride_k*ik); // const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik);
const ggml_half * vr = (const ggml_half *)((const char *)v + stride_v*ik); // auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr)));
//auto sum1 = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+0))); // for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum);
//auto sum2 = _mm512_mul_ps(vq[1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr+1))); // float s = scale * _mm512_reduce_add_ps(sum) + mv;
//for (int i = 2; i < 8; i += 2) { // if (s > M) {
// sum1 = _mm512_fmadd_ps(vq[i+0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 0)), sum1); // if (M > -INFINITY) {
// sum2 = _mm512_fmadd_ps(vq[i+1], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i + 1)), sum2); // float ms = expf(M - s);
//} // auto vms = _mm512_set1_ps(ms);
//float s = scale * _mm512_reduce_add_ps(_mm512_add_ps(sum1, sum2)) + mv; // for (int i = 0; i < 8; ++i) {
auto sum = _mm512_mul_ps(vq[0], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr))); // acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)));
for (int i = 1; i < 8; ++i) sum = _mm512_fmadd_ps(vq[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)kr + i)), sum); // }
float s = scale * _mm512_reduce_add_ps(sum) + mv; // S = ms*S + 1;
if (s > M) { // } else {
if (M > -INFINITY) { // for (int i = 0; i < 8; ++i) {
//if (M - s > -20.f) { // acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i));
float ms = expf(M - s); // }
auto vms = _mm512_set1_ps(ms); // S = 1;
for (int i = 0; i < 8; ++i) { // }
acc[i] = _mm512_fmadd_ps(vms, acc[i], _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i))); // M = s;
} // } else {
S = ms*S + 1; // float vs = expf(s - M);
} else { // auto vvs = _mm512_set1_ps(vs);
for (int i = 0; i < 8; ++i) { // for (int i = 0; i < 8; ++i) {
acc[i] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)); // acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
} // }
S = 1; // S += vs;
} // }
M = s; //}
} else { //if (S > 0) {
//if (s - M > -20.f) { // auto norm = _mm512_set1_ps(1/S);
float vs = expf(s - M); // for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
auto vvs = _mm512_set1_ps(vs); //} else {
for (int i = 0; i < 8; ++i) {
acc[i] = _mm512_fmadd_ps(vvs, _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)vr + i)), acc[i]);
}
S += vs;
//}
}
}
//if (last_i < 0) {
// std::memset(qkv, 0, 128*sizeof(float)); // std::memset(qkv, 0, 128*sizeof(float));
// return;
//} //}
//printf("%s: nk = %d, last_i = %d\n", __func__, nk, last_i); //return;
auto norm = _mm512_set1_ps(1/S);
for (int i = 0; i < 8; ++i) _mm512_storeu_ps(qkv + 16*i, _mm512_mul_ps(norm, acc[i]));
return;
} }
DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0}; DataInfo info{qk, (const char*)q, 0, size_t(stride_k), 0, 1, nullptr, 0};