Much better

commit 285b97b6bb (parent efc757c95c)
Author: Iwan Kawrakow
Date:   2025-02-28 16:53:02 +02:00

3 changed files with 80 additions and 53 deletions

@@ -8785,7 +8785,8 @@ static struct ggml_tensor * llm_build_kqv(
             0);
     cb(v, "v", il);
-    if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= k->ne[1] || q->ne[2] == 1) {
+    auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024);
+    if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
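
Note on this hunk: the gate on the single-matmul path now compares attn_max_batch against an estimate of the K*Q matrix size in MiB instead of against the KV length, so attn_max_batch effectively becomes a memory budget. A minimal sketch of the same arithmetic, with made-up sizes (kq_size_mib is a hypothetical helper, not part of the commit):

#include <cstdint>
#include <cstdio>

// Mirrors k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024) above:
// the size in MiB of the f32 K*Q matrix for n_kv cached tokens,
// n_tokens queries and n_head heads.
static int64_t kq_size_mib(int64_t n_kv, int64_t n_tokens, int64_t n_head) {
    return n_kv*n_tokens*n_head*int64_t(sizeof(float))/(1024*1024);
}

int main() {
    // 32768 cached tokens, a 1024-token batch, 32 heads:
    // 32768*1024*32*4 bytes = 4096 MiB, so any positive
    // attn_max_batch below 4096 selects the batched path.
    std::printf("%lld MiB\n", (long long) kq_size_mib(32768, 1024, 32));
    return 0;
}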
@@ -8834,13 +8835,21 @@ static struct ggml_tensor * llm_build_kqv(
         cb(cur, "kqv_merged_cont", il);
     }
     else {
+        // For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2];
+        GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]);
+        int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch;
+        n_step = std::min(n_step, int(k->ne[2]));
+        int n_per_step = (q->ne[2] + n_step - 1)/n_step;
         auto r2k = q->ne[2] / k->ne[2];
-        auto r2v = q->ne[2] / n_head_kv;
+        auto r2v = q->ne[2] / v->ne[2];
+        n_step = q->ne[2];
+        n_per_step = 1;
         ggml_tensor * kqv;
-        for (int i12 = 0; i12 < q->ne[2]; ++i12) {
+        for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) {
+            int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12;
             int i02 = i12/r2k;
-            auto k_i = ggml_view_2d(ctx, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i02);
-            auto q_i = ggml_view_2d(ctx, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i12);
+            auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
+            auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
             auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
             if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
                 ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
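
Note on this hunk: the per-head loop becomes a per-chunk loop. n_step is derived from the MiB budget and capped at the number of heads, each step processes n_per_step heads through 3D views, and this_ne12 shrinks the final chunk. The commit then pins n_step = q->ne[2] and n_per_step = 1, so in practice each step still covers a single head. A standalone sketch of the chunking arithmetic under assumed sizes:

#include <algorithm>
#include <cstdio>

int main() {
    // Assumed sizes: a 4096 MiB K*Q product, a 512 MiB budget, 32 heads
    // (the GGML_ASSERT above guarantees k->ne[2] == q->ne[2] here).
    long kq_size = 4096, attn_max_batch = 512, n_head = 32;

    int n_step     = (kq_size + attn_max_batch - 1)/attn_max_batch; // ceil(4096/512) = 8
    n_step         = std::min(n_step, int(n_head));                 // at most one step per head
    int n_per_step = (n_head + n_step - 1)/n_step;                  // 4 heads per step

    for (int i12 = 0; i12 < n_head; i12 += n_per_step) {
        // the last chunk may be smaller than n_per_step
        int this_ne12 = i12 + n_per_step <= n_head ? n_per_step : n_head - i12;
        std::printf("step covers heads [%d, %d)\n", i12, i12 + this_ne12);
    }
    return 0;
}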
@@ -8855,7 +8864,7 @@ static struct ggml_tensor * llm_build_kqv(
                 kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias);
             }
             i02 = i12 / r2v;
-            auto v_i = ggml_view_2d(ctx, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i02);
+            auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02);
             auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i);
             if (i12 == 0) {
                 kqv = kqv_i;
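
Note on the v_i view: i02 = i12/r2v selects the K/V head serving query head i12, where r2k and r2v are the grouped-query ratios (now taken from k->ne[2] and v->ne[2] rather than n_head_kv). The GGML_ASSERT added above forces both ratios to 1 on this path, but the mapping they encode looks like this (illustrative numbers, not from the commit):

#include <cstdio>

int main() {
    // Assumed GQA layout: 32 query heads sharing 8 K/V heads.
    int n_head = 32, n_head_kv = 8;
    int r2k = n_head / n_head_kv; // 4 query heads per K/V head
    for (int i12 = 0; i12 < n_head; i12 += r2k) {
        int i02 = i12/r2k; // K/V head serving query heads i12 .. i12+r2k-1
        std::printf("query heads %d..%d -> kv head %d\n", i12, i12 + r2k - 1, i02);
    }
    return 0;
}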
@@ -13625,7 +13634,8 @@ struct llm_build_context {
         }
         ggml_tensor * kqv_compressed;
-        if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kv_cache->ne[1]) {
+        auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
+        if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
             if (!pp_opt) {
                 q = ggml_permute(ctx0, q, 0, 2, 1, 3);
                 cb(q, "q_perm", il);
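
Note on this hunk: the MLA branch gets the same MiB-based gate, with kv_cache->ne[1] playing the role of the KV length. A hedged restatement of the decision (use_single_matmul is an illustrative name, not from the code):

#include <cstdint>

// Restates the new gate: take the single-matmul path when batching is
// disabled or the estimated f32 K*Q size (in MiB) fits the budget.
static bool use_single_matmul(int64_t attn_max_batch, int64_t n_kv, int64_t n_tokens, int64_t n_head) {
    const int64_t kq_size = n_kv*n_tokens*n_head*int64_t(sizeof(float))/(1024*1024); // K*Q in MiB
    return attn_max_batch <= 0 || attn_max_batch >= kq_size;
}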
@@ -13634,9 +13644,6 @@ struct llm_build_context {
             ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
             cb(kq, "kq", il);
-            //printf("kq (%ld x %ld x %ld x %ld) = kv_cache (%ld x %ld x %ld x %ld) * q (%ld x %ld x %ld x %ld)\n", kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3],
-            //        kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2], kv_cache->ne[3], q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
             if (!pp_opt) {
                 kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
                 cb(kq, "kq_perm", il);
@@ -13653,10 +13660,6 @@ struct llm_build_context {
             kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
             cb(kqv_compressed, "kqv_compressed", il);
-            //printf("kqv (%ld x %ld x %ld x %ld) = kv_cache_trans (%ld x %ld x %ld x %ld) * kq (%ld x %ld x %ld x %ld)\n",
-            //        kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2], kqv_compressed->ne[3],
-            //        kv_cache_trans->ne[0], kv_cache_trans->ne[1], kv_cache_trans->ne[2], kv_cache_trans->ne[3], kq->ne[0], kq->ne[1], kq->ne[2], kq->ne[3]);
             if (!pp_opt) {
                 kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
                 cb(kqv_compressed, "kqv_compressed_perm", il);
@@ -13664,46 +13667,25 @@ struct llm_build_context {
         } else {
-            int n_step = (q->ne[1] + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
+            int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
+            n_step = std::min(n_step, int(q->ne[2]));
+            int n_per_step = (q->ne[2] + n_step - 1)/n_step;
-            //kqv_compressed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, kv_cache_trans->ne[1], q->ne[1], q->ne[2]);
-            //printf("q->ne[1] = %ld -> need %d steps\n", q->ne[1], n_step);
-            //printf("Created kqv_compressed = %ld x %ld x %ld\n", kqv_compressed->ne[0], kqv_compressed->ne[1], kqv_compressed->ne[2]);
+            //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
-            for (int i_head = 0; i_head < q->ne[2]; ++i_head) {
-                ggml_tensor * q_i = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i_head);
+            for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
+                int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
+                ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
                 ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
                 kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                 ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
                 if (i_head == 0) {
-                    kqv_compressed = ggml_view_3d(ctx0, kqv_i, kqv_i->ne[0], kqv_i->ne[1], 1, kqv_i->nb[1], kqv_i->nb[2], 0);
+                    kqv_compressed = kqv_i;
                 } else {
                     kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
                 }
                 ggml_build_forward_expand(gf, kqv_compressed);
-                //ggml_tensor * kqv_compressed_i = ggml_view_1d(ctx0, kqv_compressed, ggml_nelements(kqv_i), kqv_compressed->nb[2]*i_head);
-                //ggml_tensor * head_i = ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
-                //ggml_build_forward_expand(gf, head_i);
             }
-            //for (int i_step = 0; i_step < n_step; ++i_step) {
-            //    int i_start = i_step * lctx.cparams.attn_max_batch;
-            //    int this_batch = i_start + lctx.cparams.attn_max_batch <= q->ne[1] ? lctx.cparams.attn_max_batch : q->ne[1] - i_start;
-            //    ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], this_batch, q->ne[2], q->nb[1], q->nb[2], i_start*q->nb[1]);
-            //    cb(q_i, "q_i", il);
-            //    ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
-            //    cb(kq_i, "kq_i", il);
-            //    ggml_tensor * mask_i = ggml_view_2d(ctx0, KQ_mask, KQ_mask->ne[0], this_batch, KQ_mask->nb[1], i_start*KQ_mask->nb[1]);
-            //    kq_i = ggml_soft_max_ext(ctx0, kq_i, mask_i, kq_scale, hparams.f_max_alibi_bias);
-            //    cb(kq_i, "kq_i_softmwax", il);
-            //    ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
-            //    cb(kqv_i, "kqv_i", il);
-            //    ggml_tensor * kqv_compressed_i = ggml_view_3d(ctx0, kqv_compressed, kqv_compressed->ne[0], this_batch, kqv_compressed->ne[2],
-            //                                     kqv_compressed->nb[1], kqv_compressed->nb[2], i_start*kqv_compressed->nb[1]);
-            //    printf("step %d (%d tokens): kqv_i = %ld x %ld x %ld, kqv_compressed_i = %ld x %ld x %ld\n", i_step, this_batch,
-            //           kqv_i->ne[0], kqv_i->ne[1], kqv_i->ne[2], kqv_compressed_i->ne[0], kqv_compressed_i->ne[1], kqv_compressed_i->ne[2]);
-            //    ggml_cpy(ctx0, kqv_i, kqv_compressed_i);
-            //}
             cb(kqv_compressed, "kqv_compressed", il);
         }
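
Note on the accumulation: each per-chunk result is appended to kqv_compressed with ggml_concat along dim 2, and ggml_build_forward_expand is called every step so the partial results are sequenced into the graph; this replaces both the view-into-a-preallocated-tensor approach and the commented-out token-batched variant removed above. A self-contained ggml sketch of the same pattern (sizes made up):

#include "ggml.h"

// Per-step results are stitched together with ggml_concat along dim 2,
// and the growing graph is expanded after every step, as in the diff.
int main() {
    ggml_init_params params = { /*mem_size   =*/ 16*1024*1024,
                                /*mem_buffer =*/ nullptr,
                                /*no_alloc   =*/ false };
    ggml_context * ctx = ggml_init(params);
    ggml_cgraph  * gf  = ggml_new_graph(ctx);

    const int n_head = 4;
    ggml_tensor * acc = nullptr;
    for (int i = 0; i < n_head; ++i) {
        // stand-in for one step's kqv_i (64 x 8 x 1)
        ggml_tensor * kqv_i = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 1);
        acc = i == 0 ? kqv_i : ggml_concat(ctx, acc, kqv_i, 2);
        ggml_build_forward_expand(gf, acc);
    }
    // acc is now 64 x 8 x n_head, analogous to kqv_compressed
    ggml_free(ctx);
    return 0;
}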