mirror of https://github.com/ikawrakow/ik_llama.cpp.git
Hopefully take care of missing V cache (MLA)
@@ -10347,34 +10347,39 @@ struct llm_build_context {
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));

-                ggml_tensor * view_v_src;
-                ggml_tensor * view_v_dst;
+                ggml_tensor * view_v_src = nullptr;
+                ggml_tensor * view_v_dst = nullptr;

-                if (flash_attn) {
-                    // NOTE: the V cache is not transposed when using flash attention
-                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+                if (kv_self.v_l.size() > il) {
+                    // Note: with MLA the V cache may not be present.
+                    if (flash_attn) {
+                        // NOTE: the V cache is not transposed when using flash attention
+                        view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                                n_embd_v_gqa, nm,
+                                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));

-                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
-                } else {
-                    view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
-                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, i));
+                        view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                                n_embd_v_gqa, nm,
+                                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+                    } else {
+                        view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
+                                nm, n_embd_v_gqa,
+                                ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                                ggml_row_size(kv_self.v_l[il]->type, i));

-                    view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
-                            ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, id));
+                        view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
+                                nm, n_embd_v_gqa,
+                                ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
+                                ggml_row_size(kv_self.v_l[il]->type, id));
+                    }
                 }

                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+                if (view_v_src && view_v_dst) {
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
+                }
             }

             i += nm - 1;
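The defrag change above is a guard pattern: the V views now start out as nullptr and the V-cache copy is only added to the graph when kv_self.v_l actually holds a tensor for layer il, because with MLA only the compressed latent KV is cached and no per-layer V tensor may exist. A minimal standalone sketch of that pattern, using a hypothetical stand-in type rather than the real ggml API:

    // guard-pattern sketch -- 'tensor' is a stand-in, not ggml_tensor
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct tensor { int id; };

    int main() {
        std::vector<tensor *> v_l;   // empty, as with MLA: no V-cache tensors
        const std::size_t il = 0;    // layer currently being defragmented

        // null-initialized so the views stay unset when the layer has no V cache
        tensor * view_v_src = nullptr;
        tensor * view_v_dst = nullptr;

        if (v_l.size() > il) {
            view_v_src = v_l[il];    // stands in for the ggml_view_2d calls
            view_v_dst = v_l[il];
        }

        // the copy is scheduled only when both views were actually created
        if (view_v_src && view_v_dst) {
            std::printf("layer %zu: scheduling V-cache copy\n", il);
        } else {
            std::printf("layer %zu: no V cache (MLA), skipping V copy\n", il);
        }
        return 0;
    }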
@@ -21445,7 +21450,10 @@ struct llama_data_write {
         const struct llama_kv_cache & kv_self = ctx->kv_self;
         const struct llama_hparams & hparams = ctx->model.hparams;

-        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+        // Misuse v_trans: 0 -> not transposed V cache
+        //                 1 -> transposed V cache
+        //                 2 -> no V cache (as it may be the case with MLA)
+        const uint32_t v_trans = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0;
         const uint32_t n_layer = hparams.n_layer;

         write(&v_trans, sizeof(v_trans));
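The session-file change keeps the old uint32_t slot but widens its meaning: 0 and 1 keep the meaning the old boolean had, and 2 now flags a missing V cache, so the on-disk layout of the field itself does not change. A hedged sketch of the encoding as a free function (encode_v_trans is my name, not a symbol in the codebase):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // 0 -> V cache present, not transposed
    // 1 -> V cache present, transposed
    // 2 -> no V cache (MLA)
    uint32_t encode_v_trans(bool v_trans, std::size_t n_v_tensors) {
        if (n_v_tensors == 0) return 2;  // plays the role of kv_self.v_l.empty()
        return v_trans ? 1u : 0u;        // the old boolean meaning is preserved
    }

    int main() {
        std::printf("%u\n", encode_v_trans(false, 32));  // 0: plain V cache
        std::printf("%u\n", encode_v_trans(true,  32));  // 1: transposed V cache
        std::printf("%u\n", encode_v_trans(false,  0));  // 2: MLA, no V cache
        return 0;
    }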
@@ -21474,7 +21482,7 @@ struct llama_data_write {
             }
         }

-        if (!kv_self.v_trans) {
+        if (kv_self.v_trans == 0) {
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -21493,7 +21501,8 @@ struct llama_data_write {
                     write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
                 }
             }
-        } else {
+        }
+        else if (v_trans == 1) {
             // When v is transposed, we also need the element size and get the element ranges from each row
             const uint32_t kv_size = kv_self.size;
             for (uint32_t il = 0; il < n_layer; ++il) {
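With the flag widened, the write path becomes a three-way dispatch: whole contiguous rows for v_trans == 0, per-row element ranges for v_trans == 1, and, implicitly, nothing at all for v_trans == 2, since neither branch is taken. A small sketch of that control flow (write_v_cache is a hypothetical stand-in; the prints replace the real serialization calls):

    #include <cstdint>
    #include <cstdio>

    void write_v_cache(uint32_t v_trans) {
        if (v_trans == 0) {
            std::printf("write whole rows per layer (contiguous V)\n");
        } else if (v_trans == 1) {
            std::printf("write element ranges per row (transposed V)\n");
        } else {
            // v_trans == 2: no V cache exists, so nothing is serialized
            std::printf("no V cache (MLA), writing nothing\n");
        }
    }

    int main() {
        const uint32_t flags[] = {0, 1, 2};
        for (uint32_t flag : flags) write_v_cache(flag);
        return 0;
    }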
@@ -21785,7 +21794,7 @@ struct llama_data_read {
             }
         }

-        if (!kv_self.v_trans) {
+        if (kv_self.v_trans == 0) {
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -21812,7 +21821,8 @@ struct llama_data_read {
                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
                }
            }
-        } else {
+        }
+        else if (v_trans == 1) {
             // For each layer, read the values for each cell (transposed)
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

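The read path mirrors the write path: it has to consume exactly what the writer produced, and for v_trans == 2 that means reading no V data at all. A round-trip sketch of the flag through a byte stream, assuming it is stored as a raw uint32_t the way write(&v_trans, sizeof(v_trans)) suggests:

    #include <cstdint>
    #include <cstdio>
    #include <sstream>

    int main() {
        // writer side: emit the flag (2 = MLA, no V cache)
        const uint32_t v_trans_out = 2;
        std::stringstream file;
        file.write(reinterpret_cast<const char *>(&v_trans_out), sizeof(v_trans_out));

        // reader side: recover the flag and take the matching branch
        uint32_t v_trans_in = 0;
        file.read(reinterpret_cast<char *>(&v_trans_in), sizeof(v_trans_in));

        if (v_trans_in == 0) {
            std::printf("read contiguous V rows\n");
        } else if (v_trans_in == 1) {
            std::printf("read transposed V elements row by row\n");
        } else {
            // nothing was written for V, so nothing must be read
            std::printf("no V cache in the session data, skipping\n");
        }
        return 0;
    }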