This works on CUDA, but (#247)

PP speed is great, almost on par with standard FA.
But TG speed is pathetic. The strangest thing is that
the slowdown is not due to FA, but due to the ffn_gate_exps
gemm, which somehow becomes very slow. WTF?

As I'm unable the resolve the slow ffn_gate_exps GEMM mystery,
for now TG goes via mla=2, PP is via FA.
Also discovered the ggml_cast op, so we don't need the aux
tensors that I had added to the KV cache.

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-03-09 16:53:55 +02:00
committed by GitHub
parent 81748fb55e
commit b096a5de7a

View File

@@ -2691,9 +2691,6 @@ struct llama_kv_cache {
// DeepSeek MLA
std::vector<struct ggml_tensor *> kv_l;
std::vector<struct ggml_tensor *> kvt_l;
ggml_tensor * kv_aux_f32 = nullptr;
ggml_tensor * k_aux = nullptr;
ggml_tensor * v_aux = nullptr;
std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
@@ -3209,28 +3206,6 @@ static bool llama_kv_cache_init(
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);
cache.kv_l.push_back(kv);
if (cparams.mla_attn > 1 && cache.kv_aux_f32 == nullptr) {
cache.kv_aux_f32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
kv_lora_rank + n_embd_head_qk_rope, kv_size);
//(n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head, kv_size);
ggml_format_name(cache.kv_aux_f32, "kv_aux_f32%d", 0);
cache.k_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_k, n_head, kv_size);
ggml_format_name(cache.k_aux, "k_aux%d", 0);
cache.v_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_v, n_head, kv_size);
ggml_format_name(cache.v_aux, "v_aux%d", 0);
//cache.kv_aux_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
// (hparams.n_embd_head_k + hparams.n_embd_head_v)*n_head, kv_size);
//ggml_format_name(cache.kv_aux, "kv_aux%d", 0);
//ggml_format_name(cache.kv_aux_2, "kv_aux%d", 2);
LLAMA_LOG_INFO("%s: allocated kv auxilary tensors as %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__,
cache.kv_aux_f32->ne[0], cache.kv_aux_f32->ne[1],
cache.k_aux->ne[0], cache.k_aux->ne[1], cache.k_aux->ne[2],
cache.v_aux->ne[0], cache.v_aux->ne[1], cache.v_aux->ne[2]);
}
} else {
auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v;
ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size);
@@ -13659,215 +13634,165 @@ struct llm_build_context {
// provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix
// multiplication, which *must* be f32.
auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0);
auto kv_cache_view_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_self.kv_aux_f32->ne[0], n_kv, kv_self.kv_aux_f32->nb[1], 0);
kv_cache_view_f32 = ggml_cpy(ctx0, kv_cache_view, kv_cache_view_f32);
auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32);
cb(kv_cache_view_f32, "kv_cache_view_f32", il);
// The no- and rorational position encoding portions of the KV cache
// The no- and rotational position encoding portions of the KV cache
auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0);
auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv,
kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank));
auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope);
//// split into {n_head * n_embd_head_qk_nope, n_tokens}
//struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
// ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
// ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
// 0);
//cb(k_nope, "k_nope", il);
//// and {n_head * n_embd_head_v, n_tokens}
//struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
// ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
// ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
// ggml_row_size(kv->type, (n_embd_head_qk_nope)));
//cb(v_states, "v_states", il);
cb(kv_f32, "kv_f32", il);
auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
cb(k_nope_f32, "k_nope_f32", il);
ggml_tensor repeater;
repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1;
auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3);
cb(k_rope_f32, "k_rope_f32", il);
auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0);
cb(k_f32, "k_f32", il);
auto k_row_size = ggml_row_size(kv_self.k_aux->type, k_f32->ne[0]);
auto k = ggml_view_3d(ctx0, kv_self.k_aux, k_f32->ne[0], k_f32->ne[1], k_f32->ne[2],
k_row_size, k_row_size*k_f32->ne[1], 0);
//kv_self.k_aux->nb[1], k_row_size, 0);
//k_row_size, kv_self.k_aux->nb[1], 0);
k = ggml_cpy(ctx0, k_f32, k);
auto k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type);
cb(k, "k", il);
auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head,
ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
cb(v_f32, "v_f32", il);
auto v_row_size = ggml_row_size(kv_self.v_aux->type, v_f32->ne[0]);
auto v = ggml_view_3d(ctx0, kv_self.v_aux, v_f32->ne[0], v_f32->ne[1], v_f32->ne[2],
v_row_size, v_row_size*v_f32->ne[1], 0);
//kv_self.v_aux->nb[1], v_row_size, 0);
//v_row_size, kv_self.v_aux->nb[1], 0);
v = ggml_cpy(ctx0, v_f32, v);
//auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_cache->nb[1], 0);
//auto kv_cache_rope = ggml_view_2d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, kv_cache->nb[1],
// ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
////kv_cache_rope = ggml_permute(ctx0, kv_cache_rope, 0, 2, 1, 3);
//auto kv_cache_nope_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_lora_rank, n_kv, kv_self.kv_aux_f32->nb[1], 0);
//kv_cache_nope_f32 = ggml_cpy(ctx0, kv_cache_nope, kv_cache_nope_f32);
//auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope_f32);
////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, model.layers[il].wkv_b->ne[1], n_kv);
////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, kv_f32->ne[0], kv_f32->ne[1]);
//auto kv = ggml_view_2d(ctx0, kv_self.kv_aux, kv_self.kv_aux->ne[0], n_kv, kv_self.kv_aux->nb[1], 0);
//kv = ggml_cpy(ctx0, kv_f32, kv);
//auto k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_kv, n_head,
// ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
// ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
//ggml_tensor repeater;
//repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1;
//auto k = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0);
//auto v = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_kv, n_head,
// ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
// ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
// ggml_row_size(kv->type, n_embd_head_qk_nope));
auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
cb(v, "v", il);
auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_concat", il);
ggml_build_forward_expand(gf, q);
kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
if (q->ne[1] <= 8) {
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
}
cb(kqv, "kqv", il);
cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
//kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
//cb(kqv_compressed, "kqv_compressed_perm", il);
}
else {
ggml_tensor * kqv_compressed;
ggml_tensor * kqv_compressed;
//printf("wkv_b: %ld x %ld x %ld kv_cache: %ld x %ld x %ld\n", model.layers[il].wkv_b->ne[0], model.layers[il].wkv_b->ne[1], model.layers[il].wkv_b->ne[2], kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2]);
struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head,
ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope),
ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0);
cb(wk_b, "wk_b", il);
struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head,
ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope),
ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0);
cb(wk_b, "wk_b", il);
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
cb(q_nope, "q_nope_perm", il);
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
//if (q_nope->ne[1] <= 32) q_nope = ggml_cont(ctx0, q_nope);
cb(q_nope, "q_nope_perm", il);
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
cb(q_nope2, "q_nope2", il);
struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
cb(q_nope2, "q_nope2", il);
//printf("q_nope2 (%ld x %ld x %ld) = wk_b (%ld x %ld x %ld) * q_nope (%ld x %ld x %ld)\n", q_nope2->ne[0], q_nope2->ne[1], q_nope2->ne[2],
// wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], q_nope->ne[0], q_nope->ne[1], q_nope->ne[2]);
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0);
cb(q, "q", il);
if (lctx.cparams.flash_attn) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache_lora, "kv_cache_lora", il);
//ggml_tensor * v = ggml_cont(ctx0, kv_cache_lora);
//kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
cb(kqv_compressed, "kqv_compressed", il);
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
else {
if (lctx.cparams.mla_attn > 1) {
if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache, "kv_cache_lora", il);
cb(kv_cache_lora, "kv_cache_lora", il);
kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
cb(kv_cache_trans, "kv_cache_trans", il);
}
auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f);
cb(kqv_compressed, "kqv_compressed", il);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
else {
if (lctx.cparams.mla_attn > 1) {
ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il],
kv_lora_rank, n_kv,
ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
cb(kv_cache, "kv_cache_lora", il);
kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora));
cb(kv_cache_trans, "kv_cache_trans", il);
}
} else {
int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
n_step = std::min(n_step, int(q->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
//printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step);
for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = kqv_i;
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB
if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) {
if (!pp_opt) {
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
cb(q, "q_perm", il);
}
ggml_build_forward_expand(gf, kqv_compressed);
ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q);
cb(kq, "kq", il);
if (!pp_opt) {
kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
cb(kq, "kq_perm", il);
}
kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);
if (!pp_opt) {
kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
cb(kq, "kq_soft_max_ext_perm", il);
}
kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
cb(kqv_compressed, "kqv_compressed", il);
if (!pp_opt) {
kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
cb(kqv_compressed, "kqv_compressed_perm", il);
}
} else {
int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch;
n_step = std::min(n_step, int(q->ne[2]));
int n_per_step = (q->ne[2] + n_step - 1)/n_step;
for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) {
int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head;
ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head);
ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i);
kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i);
if (i_head == 0) {
kqv_compressed = kqv_i;
} else {
kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2);
}
ggml_build_forward_expand(gf, kqv_compressed);
}
cb(kqv_compressed, "kqv_compressed", il);
}
cb(kqv_compressed, "kqv_compressed", il);
}
}
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank),
ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0);
cb(wv_b, "wv_b", il);
struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head,
ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank),
ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0);
cb(wv_b, "wv_b", il);
kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
cb(kqv, "kqv", il);
kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
cb(kqv, "kqv", il);
kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
cb(kqv, "kqv_perm", il);
kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
cb(kqv, "kqv_perm", il);
cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
cb(cur, "kqv_2d", il);
cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
cb(cur, "kqv_2d", il);
}
ggml_build_forward_expand(gf, cur);