Handle split cache (read)

Kawrakow
2025-12-08 10:55:35 +02:00
parent 1e50392cd0
commit be8e7057b3


@@ -5716,6 +5716,41 @@ struct llama_data_read {
        return true;
    }

    void read_kv_cache_data_split(llama_context * ctx, ggml_tensor * tensor, const uint8_t * data, size_t head, size_t row_size, int nrows, int il) {
        GGML_ASSERT(il >= 0 && il < int(ctx->model.layers.size()));
        GGML_ASSERT(ggml_internal_get_type_traits(tensor->type).row_meta_size == 0);
        // the attention weight (wk or wv) whose per-device splits the cache tensor mirrors
        auto kv = tensor->ne[1] > 1 ? ctx->model.layers[il].wk : ctx->model.layers[il].wv;
        auto extra    = (ggml_split_tensor_t *)tensor->extra;
        auto kv_extra = (ggml_split_tensor_t *)kv->extra;
        GGML_ASSERT(extra && kv_extra);
        auto ne = kv->ne[1]; // full row length, summed over all device splits
        size_t sum_ne = 0;
        size_t sum_split_row_size = 0;
        GGML_ASSERT(row_size == ggml_row_size(tensor->type, ne));
        std::vector<uint8_t> aux;
        for (int id = 0; id < extra->n_device; ++id) {
            auto split    = extra->splits[id];
            auto kv_split = kv_extra->splits[id];
            GGML_ASSERT((split && kv_split) || (!split && !kv_split));
            if (!split) continue;
            GGML_ASSERT(split->type == tensor->type);
            // gather this device's slice of every row into a contiguous staging buffer
            auto split_row_size = ggml_row_size(tensor->type, kv_split->ne[1]);
            aux.resize(split_row_size*nrows);
            auto src = data + sum_split_row_size;
            auto dst = aux.data();
            for (int row = 0; row < nrows; ++row) {
                std::memcpy(dst, src, split_row_size);
                dst += split_row_size;
                src += row_size;
            }
            ggml_backend_tensor_set(split, aux.data(), head*split_row_size, nrows*split_row_size);
            sum_ne += kv_split->ne[1];
            sum_split_row_size += split_row_size;
        }
        // the per-device splits must cover the full rows exactly
        GGML_ASSERT(sum_ne == ne);
        GGML_ASSERT(sum_split_row_size == row_size);
    }

    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
        const struct llama_hparams & hparams = ctx->model.hparams;
        struct llama_kv_cache & kv_self = ctx->kv_self;
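
The new read_kv_cache_data_split above de-interleaves the contiguous KV rows stored in the session file back into the per-device split tensors: for each split it copies that device's slice out of every full row into the aux staging buffer, then uploads the buffer with ggml_backend_tensor_set at byte offset head*split_row_size. Below is a minimal sketch of that scatter step on plain host buffers; the names scatter_rows_to_splits and split_row_sizes are illustrative and not part of ik_llama.cpp, and explicit byte counts stand in for ggml_row_size.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative only: scatter nrows full rows (row_size bytes each) into
// per-split staging buffers, where split_row_sizes[i] bytes of every row
// belong to split i. Mirrors the memcpy loop in read_kv_cache_data_split,
// with the ggml_backend_tensor_set upload replaced by returning the buffers.
std::vector<std::vector<uint8_t>> scatter_rows_to_splits(
        const uint8_t * data, size_t row_size, int nrows,
        const std::vector<size_t> & split_row_sizes) {
    std::vector<std::vector<uint8_t>> out(split_row_sizes.size());
    size_t col_offset = 0; // byte offset of the current split's slice inside a full row
    for (size_t i = 0; i < split_row_sizes.size(); ++i) {
        const size_t split_row_size = split_row_sizes[i];
        out[i].resize(split_row_size * nrows);
        for (int row = 0; row < nrows; ++row) {
            std::memcpy(out[i].data() + row * split_row_size,
                        data + row * row_size + col_offset,
                        split_row_size);
        }
        col_offset += split_row_size;
    }
    // the splits must cover each row exactly (sum_split_row_size == row_size in the commit)
    assert(col_offset == row_size);
    return out;
}
```

Each entry of the returned vector corresponds to one aux upload in the commit, and the final assert mirrors the sum_split_row_size == row_size check at the end of read_kv_cache_data_split.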
@@ -5770,7 +5805,11 @@ struct llama_data_read {
            if (cell_count) {
                // Read and set the keys for the whole cell range
                if (kv_self.k_l[il]->extra) {
                    read_kv_cache_data_split(ctx, kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head, k_size_row, cell_count, il);
                } else {
                    ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
                }
            }
        }
@@ -5798,7 +5837,11 @@ struct llama_data_read {
            if (cell_count) {
                // Read and set the values for the whole cell range
                if (kv_self.v_l[il]->extra) {
                    read_kv_cache_data_split(ctx, kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head, v_size_row, cell_count, il);
                } else {
                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
                }
            }
        }
        }
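
The changes above only cover restoring a split cache; a matching save path would have to perform the inverse gather, concatenating each device's sub-rows back into full rows before serializing them. That is not part of this commit, so the sketch below is only illustrative and reuses the hypothetical host-buffer representation from the earlier example.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative inverse of scatter_rows_to_splits: rebuild contiguous full rows
// from per-split staging buffers. Names and layout are assumptions, not the
// ik_llama.cpp write path.
std::vector<uint8_t> gather_rows_from_splits(
        const std::vector<std::vector<uint8_t>> & splits,
        const std::vector<size_t> & split_row_sizes, int nrows) {
    size_t row_size = 0;
    for (size_t s : split_row_sizes) row_size += s;
    std::vector<uint8_t> data(row_size * nrows);
    size_t col_offset = 0;
    for (size_t i = 0; i < splits.size(); ++i) {
        const size_t split_row_size = split_row_sizes[i];
        for (int row = 0; row < nrows; ++row) {
            std::memcpy(data.data() + row * row_size + col_offset,
                        splits[i].data() + row * split_row_size,
                        split_row_size);
        }
        col_offset += split_row_size;
    }
    return data;
}
```

Round-tripping scatter_rows_to_splits and gather_rows_from_splits reproduces the original buffer, which is the same invariant the GGML_ASSERTs in read_kv_cache_data_split enforce on the tensor side.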
@@ -5834,6 +5877,9 @@ struct llama_data_read {
                }
                if (cell_count) {
                    if (kv_self.v_l[il]->extra) {
                        throw std::runtime_error("Transposed V cache is not supported with split mode 'graph'");
                    }
                    // For each row in the transposed matrix, read the values for the whole cell range
                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;