Replace MLA-specific KV cache with the standard KV cache V2 (#473)

* Fix save and restore when there is no V cache

* Fix double print
This commit is contained in:
saood06
2025-05-30 02:28:27 -05:00
committed by GitHub
parent ac27355e3b
commit df257a07e6

View File

@@ -20699,41 +20699,21 @@ struct llama_context * llama_new_context_with_model(
}
if (memory_size_k + memory_size_v > 0) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
}
{
size_t memory_size_kv = 0;
size_t memory_size_kvt = 0;
ggml_type kv_type = GGML_TYPE_COUNT;
ggml_type kvt_type = GGML_TYPE_COUNT;
for (auto & kv : ctx->kv_self.k_l) {
memory_size_kv += ggml_nbytes(kv);
kv_type = kv->type;
}
for (auto & kvt : ctx->kv_self.v_l) {
memory_size_kvt += ggml_nbytes(kvt);
kvt_type = kvt->type;
}
if (memory_size_kv + memory_size_kvt > 0) {
if (cparams.mla_attn == 1 && !cparams.flash_attn) {
if (cparams.mla_attn != 0 && !cparams.flash_attn) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f),
ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f));
} else {
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
} else if (cparams.mla_attn != 0 && cparams.flash_attn) {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__,
(float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f),
ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f));
}
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f));
} else {
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
}
}
@@ -21450,13 +21430,13 @@ struct llama_data_write {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;
// Misuse v_trans: 0 -> not transposed V cache
// 1 -> transposed V cache
// 2 -> no V cache (as it maybe the case with MLA)
const uint32_t v_trans = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0;
// v_state: 0 -> not transposed V cache
// 1 -> transposed V cache
// 2 -> no V cache (as it may be the case with MLA)
const uint32_t v_state = kv_self.v_l.empty() ? 2 : kv_self.v_trans ? 1 : 0;
const uint32_t n_layer = hparams.n_layer;
write(&v_trans, sizeof(v_trans));
write(&v_state, sizeof(v_state));
write(&n_layer, sizeof(n_layer));
std::vector<uint8_t> tmp_buf;
@@ -21482,7 +21462,7 @@ struct llama_data_write {
}
}
if (kv_self.v_trans == 0) {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@@ -21502,7 +21482,7 @@ struct llama_data_write {
}
}
}
else if (v_trans == 1) {
else if (v_state == 1) {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (uint32_t il = 0; il < n_layer; ++il) {
@@ -21748,9 +21728,13 @@ struct llama_data_read {
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
const struct llama_hparams & hparams = ctx->model.hparams;
struct llama_kv_cache & kv_self = ctx->kv_self;
uint32_t v_trans;
// v_state: 0 -> not transposed V cache
// 1 -> transposed V cache
// 2 -> no V cache (as it may be the case with MLA)
uint32_t v_state;
uint32_t n_layer;
read_to(&v_trans, sizeof(v_trans));
read_to(&v_state, sizeof(v_state));
read_to(&n_layer, sizeof(n_layer));
if (n_layer != hparams.n_layer) {
@@ -21761,7 +21745,9 @@ struct llama_data_read {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
return false;
}
if (kv_self.v_trans != (bool) v_trans) {
// Currently the only way there is no V cache (and thus v_state is 2) requires flash_attn, and flash_attn sets kv_self.v_trans to false
if (kv_self.v_trans != (v_state == 1)) {
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
return false;
}
@@ -21794,7 +21780,7 @@ struct llama_data_read {
}
}
if (kv_self.v_trans == 0) {
if (v_state == 0) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@@ -21822,7 +21808,7 @@ struct llama_data_read {
}
}
}
else if (v_trans == 1) {
else if (v_state == 1) {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();