Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-03-11 06:20:09 +00:00)
Replace MLA-specific KV cache with the standard KV cache V2 (#473)
* Fix save and restore when there is no V cache
* Fix double print
@@ -20699,41 +20699,21 @@ struct llama_context * llama_new_context_with_model(
             }
 
             if (memory_size_k + memory_size_v > 0) {
-                LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-            }
-        }
-
-        {
-            size_t memory_size_kv = 0;
-            size_t memory_size_kvt = 0;
-
-            ggml_type kv_type = GGML_TYPE_COUNT;
-            ggml_type kvt_type = GGML_TYPE_COUNT;
-
-            for (auto & kv : ctx->kv_self.k_l) {
-                memory_size_kv += ggml_nbytes(kv);
-                kv_type = kv->type;
-            }
-
-            for (auto & kvt : ctx->kv_self.v_l) {
-                memory_size_kvt += ggml_nbytes(kvt);
-                kvt_type = kvt->type;
-            }
-
-            if (memory_size_kv + memory_size_kvt > 0) {
-                if (cparams.mla_attn == 1 && !cparams.flash_attn) {
+                if (cparams.mla_attn != 0 && !cparams.flash_attn) {
                     LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
-                        (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
-                        ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
-                        ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
-                } else {
+                        (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                        ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                        ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+                } else if (cparams.mla_attn != 0 && cparams.flash_attn) {
                     LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__,
-                        (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
-                        ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f));
-                }
+                        (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                        ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f));
+                } else {
+                    LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                        (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                        ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                        ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+                }
             }
         }
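The hunk above removes the second, MLA-specific size computation, which duplicated the loop over kv_self.k_l/v_l and produced the double print mentioned in the commit message, and folds the label choice into the existing block, so exactly one "KV self size" line is printed. Below is a minimal standalone sketch of the resulting three-way selection; the function name and the plain printf are illustrative, only the mla_attn/flash_attn combinations mirror the diff:

    #include <cstdio>

    // Which labels the single "KV self size" line gets, depending on how the cache is used:
    //  - MLA without flash attention: K holds c^KV, V holds kv^T
    //  - MLA with flash attention:    K holds c^KV, no V cache is allocated
    //  - everything else:             plain K and V caches
    static void print_kv_self_size(int mla_attn, bool flash_attn, double k_mib, double v_mib) {
        if (mla_attn != 0 && !flash_attn) {
            printf("KV self size = %7.2f MiB, c^KV: %7.2f MiB, kv^T: %7.2f MiB\n", k_mib + v_mib, k_mib, v_mib);
        } else if (mla_attn != 0 && flash_attn) {
            printf("KV self size = %7.2f MiB, c^KV: %7.2f MiB, kv^T: not used\n", k_mib, k_mib);
        } else {
            printf("KV self size = %7.2f MiB, K: %7.2f MiB, V: %7.2f MiB\n", k_mib + v_mib, k_mib, v_mib);
        }
    }

    int main() {
        // Example: an MLA model running with flash attention, 610.3 MiB of c^KV and no kv^T cache.
        print_kv_self_size(/*mla_attn=*/2, /*flash_attn=*/true, 610.3, 0.0);
        return 0;
    }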
@@ -21450,13 +21430,13 @@ struct llama_data_write {
         const struct llama_kv_cache & kv_self = ctx->kv_self;
         const struct llama_hparams & hparams = ctx->model.hparams;
 
-        // Misuse v_trans: 0 -> not transposed V cache
-        //                 1 -> transposed V cache
-        //                 2 -> no V cache (as it maybe the case with MLA)
-        const uint32_t v_trans = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0;
+        // v_state: 0 -> not transposed V cache
+        //          1 -> transposed V cache
+        //          2 -> no V cache (as it may be the case with MLA)
+        const uint32_t v_state = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0;
         const uint32_t n_layer = hparams.n_layer;
 
-        write(&v_trans, sizeof(v_trans));
+        write(&v_state, sizeof(v_state));
         write(&n_layer, sizeof(n_layer));
 
         std::vector<uint8_t> tmp_buf;
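With MLA plus flash attention the context has no V cache at all, so the session writer can no longer store a plain transposed/not-transposed boolean. The hunk above replaces it with a tri-state flag. A minimal sketch of that encoding, assuming only that v_l is the per-layer list of V tensors and v_trans the transposition flag, as in the code above (the helper name is hypothetical):

    #include <cstdint>
    #include <vector>

    struct ggml_tensor;  // opaque here; only whether the list is empty matters for the flag

    // Tri-state written to the session file, with the same meaning as v_state above:
    //   0 -> V cache present, not transposed
    //   1 -> V cache present, transposed
    //   2 -> no V cache at all (e.g. MLA with flash attention)
    static uint32_t encode_v_state(const std::vector<ggml_tensor *> & v_l, bool v_trans) {
        return v_l.empty() ? 2u : v_trans ? 1u : 0u;
    }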
@@ -21482,7 +21462,7 @@ struct llama_data_write {
             }
         }
 
-        if (kv_self.v_trans == 0) {
+        if (v_state == 0) {
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@@ -21502,7 +21482,7 @@ struct llama_data_write {
                 }
             }
         }
-        else if (v_trans == 1) {
+        else if (v_state == 1) {
             // When v is transposed, we also need the element size and get the element ranges from each row
             const uint32_t kv_size = kv_self.size;
             for (uint32_t il = 0; il < n_layer; ++il) {
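The two write paths differ in memory layout. When the V cache is not transposed, a layer's values for the saved cells are contiguous and can be copied in one block; when it is transposed, each embedding row holds the values of all cells, so the writer needs the element size and must copy a range out of every row. The real code works on (possibly quantized) ggml tensors and arbitrary cell ranges; the float-only sketch below, with hypothetical helper names, only illustrates that layout difference:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Simplified model of the two write paths selected by v_state above. A layer's V data is
    // viewed as kv_size rows of n_embd floats when not transposed, and as n_embd rows of
    // kv_size floats when transposed; only the first cell_count cells are saved here.
    static void write_v_not_transposed(const float * v, uint32_t n_embd, uint32_t cell_count,
                                       std::vector<uint8_t> & out) {
        // Cells are contiguous: one copy of cell_count rows of n_embd values.
        const size_t bytes = (size_t)cell_count * n_embd * sizeof(float);
        const uint8_t * p = (const uint8_t *)v;
        out.insert(out.end(), p, p + bytes);
    }

    static void write_v_transposed(const float * v, uint32_t n_embd, uint32_t kv_size,
                                   uint32_t cell_count, std::vector<uint8_t> & out) {
        // Cells are strided: take the first cell_count elements out of each of the n_embd rows.
        for (uint32_t row = 0; row < n_embd; ++row) {
            const uint8_t * p = (const uint8_t *)(v + (size_t)row * kv_size);
            out.insert(out.end(), p, p + (size_t)cell_count * sizeof(float));
        }
    }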
@@ -21748,9 +21728,13 @@ struct llama_data_read {
     bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
         const struct llama_hparams & hparams = ctx->model.hparams;
         struct llama_kv_cache & kv_self = ctx->kv_self;
-        uint32_t v_trans;
+
+        // v_state: 0 -> not transposed V cache
+        //          1 -> transposed V cache
+        //          2 -> no V cache (as it may be the case with MLA)
+        uint32_t v_state;
         uint32_t n_layer;
-        read_to(&v_trans, sizeof(v_trans));
+        read_to(&v_state, sizeof(v_state));
         read_to(&n_layer, sizeof(n_layer));
 
         if (n_layer != hparams.n_layer) {
@@ -21761,7 +21745,9 @@ struct llama_data_read {
             LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
             return false;
         }
-        if (kv_self.v_trans != (bool) v_trans) {
+
+        // Currently the only way there is no V cache (and thus v_state is 2) requires flash_attn, and flash_attn sets kv_self.v_trans to false
+        if (kv_self.v_trans != (v_state == 1)) {
             LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
             return false;
         }
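On restore, the compatibility check must accept all three states. States 0 and 1 have to match kv_self.v_trans exactly; state 2 is allowed because, as the new comment notes, a missing V cache currently implies flash attention, and with flash attention kv_self.v_trans is false. A one-line sketch of that rule (the function name is illustrative):

    #include <cstdint>

    // v_state from the session file: 0 = V not transposed, 1 = V transposed, 2 = no V cache.
    // The restoring context only knows whether its own V cache is transposed, so state 2 is
    // treated like "not transposed" -- valid because the no-V-cache case (MLA + flash attention)
    // always runs with v_trans == false.
    static bool v_state_compatible(uint32_t v_state, bool ctx_v_trans) {
        return ctx_v_trans == (v_state == 1);
    }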
@@ -21794,7 +21780,7 @@ struct llama_data_read {
             }
         }
 
-        if (kv_self.v_trans == 0) {
+        if (v_state == 0) {
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@@ -21822,7 +21808,7 @@ struct llama_data_read {
                 }
             }
         }
-        else if (v_trans == 1) {
+        else if (v_state == 1) {
             // For each layer, read the values for each cell (transposed)
             for (uint32_t il = 0; il < n_layer; ++il) {
                 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();