From bc4be331ee3fe4f1e1d01c7e93ddcfecd3666caa Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 25 Nov 2025 15:30:37 +0000 Subject: [PATCH] WIP: also allocate the KV cache using tensor split --- ggml/src/ggml-backend.cpp | 2 +- ggml/src/ggml-cuda.cu | 57 ++++++++++++++++++++++++++------------- src/llama.cpp | 2 ++ 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 813f4467..e42b05ec 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -43,7 +43,7 @@ GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buf // get_alloc_size is optional, defaults to ggml_nbytes if (buft->iface.get_alloc_size) { size_t size = buft->iface.get_alloc_size(buft, tensor); - assert(size >= ggml_nbytes(tensor)); + //assert(size >= ggml_nbytes(tensor)); return size; } return ggml_nbytes(tensor); diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index dd9498fd..d2f669f6 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1010,31 +1010,50 @@ GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; +GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size([[maybe_unused]] ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + if (!tensor->extra) return 0; + auto extra = (ggml_split_tensor_t *)tensor->extra; + GGML_ASSERT(extra->n_device <= ggml_backend_cuda_get_device_count()); size_t total_size = 0; - - const int64_t ne0 = tensor->ne[0]; - - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - int64_t row_low, row_high; - get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id); - - int64_t nrows_split = row_high - row_low; - if (nrows_split == 0) { - continue; - } - - total_size += ggml_nbytes_split(tensor, nrows_split); - - // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + for (int i = 0; i < extra->n_device; ++i) { + auto split = extra->splits[i]; + if (!split) continue; + total_size += ggml_nbytes(split); + auto ne0 = split->ne[0]; if (ne0 % MATRIX_ROW_PADDING != 0) { - total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + auto nblock = (ne0 + MATRIX_ROW_PADDING - 1)/MATRIX_ROW_PADDING; + auto row_size = ggml_row_size(split->type, ne0); + auto padded_row_size = ggml_row_size(split->type, nblock*MATRIX_ROW_PADDING); + total_size += padded_row_size - row_size; } } - return total_size; + + //ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; + + //size_t total_size = 0; + + //const int64_t ne0 = tensor->ne[0]; + + //for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + // int64_t row_low, row_high; + // get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id); + + // int64_t nrows_split = row_high - row_low; + // if (nrows_split == 0) { + // continue; + // } + + // total_size += ggml_nbytes_split(tensor, nrows_split); + + // // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + // if (ne0 % MATRIX_ROW_PADDING != 0) { + // total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + // } + //} + + //return total_size; } GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { diff --git a/src/llama.cpp b/src/llama.cpp index cce9729d..b214ab2c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -770,6 +770,8 @@ static bool llama_kv_cache_init( split_v_l.ggml.n_device = extra_V->n_device; split_v_l.ggml.split_dim = 0; split_v_l.ggml.splits = split_v_l.tensor_splits.data(); + k->extra = (void *)&split_k_l.ggml; + v->extra = (void *)&split_v_l.ggml; } else { printf("Oops: don't have yet K and V for layer %d\n", i); }