server: enable checkpoint for recurrent models (#1310)

* server: enable checkpoint for recurrent models

create checkpoint after cancel

fix ban string and rm context during rewind

add checkpoint interval

only save recurrent cache

* save checkpoint during pp

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2026-02-25 23:51:18 -06:00
committed by GitHub
parent 216f44363f
commit 3fac78c48b
11 changed files with 204 additions and 111 deletions

View File

@@ -645,6 +645,8 @@ extern "C" {
// Returns true if the model is hybrid (like Jamba, Granite, etc.)
LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
// Presumably returns true if the model contains any recurrent layers (pure recurrent or hybrid)
// NOTE(review): inferred from the name only; the exact criteria are not visible in this header - confirm against the implementation
LLAMA_API bool llama_model_has_recurrent(const struct llama_model * model);
// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
@@ -735,6 +737,11 @@ extern "C" {
llama_seq_id * cells_sequences;
};
// Flags accepted by the per-sequence state functions (llama_state_seq_get_size, llama_state_seq_get_data, ...)
// LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY: work only with partial states, such as recurrent cache (e.g. Mamba)
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
// NOTE(review): presumably a bitmask (values OR-ed together) where 0 means "full sequence state" - confirm with the implementation
typedef uint32_t llama_state_seq_flags;
// Create an empty KV cache view. (use only for debugging purposes)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
@@ -813,6 +820,11 @@ extern "C" {
struct llama_context * ctx,
llama_seq_id seq_id);
// Returns the smallest position present in the KV cache for the specified sequence
// NOTE(review): the return value when the sequence has no cells in the cache is not stated here - confirm (and document) the sentinel
LLAMA_API llama_pos llama_kv_cache_seq_pos_min(
struct llama_context * ctx,
llama_seq_id seq_id);
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
@@ -889,14 +901,16 @@ extern "C" {
// Get the exact size needed to copy the KV cache of a single sequence
LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx,
llama_seq_id seq_id);
llama_seq_id seq_id,
llama_state_seq_flags flags);
// Copy the KV cache of a single sequence into the specified buffer
LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx,
uint8_t * dst,
size_t size,
llama_seq_id seq_id);
llama_seq_id seq_id,
llama_state_seq_flags flags);
// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
// Returns:
@@ -906,7 +920,8 @@ extern "C" {
struct llama_context * ctx,
const uint8_t * src,
size_t size,
llama_seq_id dest_seq_id);
llama_seq_id dest_seq_id,
llama_state_seq_flags flags);
LLAMA_API size_t llama_state_seq_save_file(
struct llama_context * ctx,