Hopefully take care of missing V cache (MLA)

This commit is contained in:
Iwan Kawrakow
2025-05-28 14:16:01 +03:00
parent edd049b0d3
commit ac27355e3b

View File

@@ -10347,34 +10347,39 @@ struct llm_build_context {
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
ggml_tensor * view_v_src = nullptr;
ggml_tensor * view_v_dst = nullptr;
if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
if (kv_self.v_l.size() > il) {
// Note: with MLA the V cache may not be present.
if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, i));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, i));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, id));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, id));
}
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
if (view_v_src && view_v_dst) {
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
}
}
i += nm - 1;
@@ -21445,7 +21450,10 @@ struct llama_data_write {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
// Misuse v_trans: 0 -> not transposed V cache
// 1 -> transposed V cache
// 2 -> no V cache (as it maybe the case with MLA)
const uint32_t v_trans = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0;
const uint32_t n_layer = hparams.n_layer;
write(&v_trans, sizeof(v_trans));
@@ -21474,7 +21482,7 @@ struct llama_data_write {
}
}
if (!kv_self.v_trans) {
if (kv_self.v_trans == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@@ -21493,7 +21501,8 @@ struct llama_data_write {
write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
}
}
} else {
}
else if (v_trans == 1) {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (uint32_t il = 0; il < n_layer; ++il) {
@@ -21785,7 +21794,7 @@ struct llama_data_read {
}
}
if (!kv_self.v_trans) {
if (kv_self.v_trans == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@@ -21812,7 +21821,8 @@ struct llama_data_read {
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
}
}
} else {
}
else if (v_trans == 1) {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();