mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-06 14:30:09 +00:00
* Remove most of split mode row
* WIP
* WIP: also allocate the KV cache using tensor split
* WIP: it runs with wrong result
But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed attn evaluation, GPU 1 must wait for GPU 0 to finish its
entire attn calculation
* Same with FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
wait for GPU 0 to finish its entire FFN calculation before it can
start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduling
* WIP
* This works, but it is slow
* This is slightly better
the graph is still not being computed in parallel.
Why? Because the scheduler creates graph splits where the
result of the computation on one GPU becomes an input for the
other split. Hence, to trigger the computation on the second GPU
one needs to wait for the computation on the first GPU to finish,
even though the two can be done in parallel up to the synchronization
point. So, all that is left to do is to trick the scheduler to create
two splits that can be done in parallel, and then have a graph split
where the results get combined.
* Playing games with the scheduler
This change tricks it into doing the right thing^TM.
Still quite a bit slower than split mode layer for the 8B LlaMA model.
But for the 70B LlaMA it now beats split mode layer for TG:
28 t/s vs 24.4 t/s. PP is 627 t/s vs 744 t/s.
In comparison, split mode "row" in mainline gets
484 t/s PP and 19.3 t/s TG.
* Fix attn split
Granularity for Wq, Wo is not just head size, but
head size * gqa_ratio.
Else the Wk, Wv tensors end up not being a multiple of the
head size when we divide the split determined by Wo with
the gqa_ratio.
* Show memory used per device
* Make it work with partial offload
but no tensor overrides yet, just ngl < num_layers.
* Allow for f16 source in fused_rms_norm
* This results in faster PP.
Now PP is faster than split mode layer for L3-70B.
* Rename split mode "row" to split mode "graph"
* Leave FFN partial results as f16
* WIP GLM4.5 - runs with wrong results
* WIP GLM4.5 - this works
PP is already better than split mode layer, but TG for zero context
is kind of low - 60 vs 92 t/s. TG becomes better than split mode layer
at around 20k tokens. PP at 26k tokens is 1.55X of sm layer.
* Work around compiler bug
It issues a warning that there is an extra semicolon outside of a function,
but there isn't. If I remove the anonymous namespace and turn the
functions inside into static, the warning disappears, so clearly
a compiler bug.
* Make graph reuse work with split mode graph
* Remove more split mode row remnants
* WIP tensor overrides
Runs with wrong results, don't see where the issue could be.
* This works but is slow
Still does not work for row-interleaved quants
* Slightly better
* Slightly better
* Row-interleaved quants work
* Better
* Minor
* Guard against using split mode "graph" for unsupported models
* Guards against using merge_qkv with split mode "graph"
* WIP split mode attn
Works for LlaMA models, but not for GLM-4.5.
Doesn't seem to improve performance, so I guess no point in trying to
fix it.
* Split mode graph for qwen3moe
* Try to better distribute the splits
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
227 lines
6.9 KiB
C++
227 lines
6.9 KiB
C++
#pragma once
|
|
|
|
#include "llama-impl.h"
|
|
#include "llama-cparams.h"
|
|
#include "llama-sampling.h"
|
|
|
|
struct llama_model;
|
|
|
|
#include <vector>
|
|
#include <map>
|
|
#include <set>
|
|
#include <memory>
|
|
|
|
struct llama_kv_cell {
|
|
llama_pos pos = -1;
|
|
llama_pos delta = 0;
|
|
int32_t src = 0; // used by recurrent state models to copy states
|
|
|
|
std::set<llama_seq_id> seq_id;
|
|
|
|
bool has_seq_id(const llama_seq_id & id) const {
|
|
return seq_id.find(id) != seq_id.end();
|
|
}
|
|
|
|
bool is_empty() const {
|
|
return seq_id.empty();
|
|
}
|
|
|
|
bool is_same_seq(const llama_kv_cell & other) const {
|
|
return seq_id == other.seq_id;
|
|
}
|
|
};
|
|
|
|
// ring-buffer of cached KV data
|
|
struct llama_kv_cache {
|
|
bool has_shift = false;
|
|
bool do_defrag = false;
|
|
bool do_copy = false;
|
|
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
|
bool v_trans = true; // the value tensor is transposed
|
|
|
|
// Note: The value of head isn't only used to optimize searching
|
|
// for a free KV slot. llama_decode_internal also uses it, so it
|
|
// cannot be freely changed after a slot has been allocated.
|
|
uint32_t head = 0;
|
|
uint32_t size = 0;
|
|
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
|
|
|
// computed before each graph build
|
|
uint32_t n = 0;
|
|
|
|
ggml_type type_k = GGML_TYPE_F16;
|
|
ggml_type type_v = GGML_TYPE_F16;
|
|
|
|
std::vector<llama_kv_cell> cells;
|
|
|
|
std::vector<struct ggml_tensor *> k_l; // per layer
|
|
std::vector<struct ggml_tensor *> v_l;
|
|
|
|
std::vector<llama_split_tensor> split_k_l;
|
|
std::vector<llama_split_tensor> split_v_l;
|
|
|
|
std::vector<struct ggml_context *> ctxs;
|
|
std::vector<ggml_backend_buffer_t> bufs;
|
|
|
|
size_t total_size() const {
|
|
size_t size = 0;
|
|
for (ggml_backend_buffer_t buf : bufs) {
|
|
size += ggml_backend_buffer_get_size(buf);
|
|
}
|
|
return size;
|
|
}
|
|
|
|
~llama_kv_cache() {
|
|
for (struct ggml_context * ctx : ctxs) {
|
|
ggml_free(ctx);
|
|
}
|
|
for (ggml_backend_buffer_t buf : bufs) {
|
|
ggml_backend_buffer_free(buf);
|
|
}
|
|
}
|
|
};
|
|
|
|
struct llama_control_vector {
|
|
std::vector<struct ggml_tensor *> tensors; // per layer
|
|
std::vector<struct ggml_context *> ctxs;
|
|
std::vector<ggml_backend_buffer_t> bufs;
|
|
|
|
int32_t layer_start = -1;
|
|
int32_t layer_end = -1;
|
|
|
|
struct ggml_tensor * tensor_for(int il) const {
|
|
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
|
|
return nullptr;
|
|
}
|
|
return tensors[il];
|
|
}
|
|
|
|
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
|
|
ggml_tensor * layer_dir = tensor_for(il);
|
|
if (layer_dir != nullptr) {
|
|
cur = ggml_add(ctx, cur, layer_dir);
|
|
}
|
|
return cur;
|
|
}
|
|
|
|
~llama_control_vector() {
|
|
for (struct ggml_context * ctx : ctxs) {
|
|
ggml_free(ctx);
|
|
}
|
|
for (ggml_backend_buffer_t buf : bufs) {
|
|
ggml_backend_buffer_free(buf);
|
|
}
|
|
}
|
|
};
|
|
|
|
struct llama_context {
|
|
|
|
llama_context(const llama_model & model);
|
|
|
|
~llama_context();
|
|
|
|
const struct llama_model & model;
|
|
|
|
struct llama_cparams cparams;
|
|
struct llama_sampling sampling;
|
|
struct llama_kv_cache kv_self;
|
|
struct llama_control_vector cvec;
|
|
|
|
std::vector<float> scale_data;
|
|
|
|
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
|
|
|
|
std::vector<ggml_backend_t> backends;
|
|
#ifdef GGML_USE_METAL
|
|
ggml_backend_t backend_metal = nullptr;
|
|
#endif
|
|
#ifdef GGML_USE_BLAS
|
|
ggml_backend_t backend_blas = nullptr;
|
|
#endif
|
|
ggml_backend_t backend_cpu = nullptr;
|
|
|
|
bool has_evaluated_once = false;
|
|
|
|
int64_t t_start_us;
|
|
int64_t t_load_us;
|
|
int64_t t_p_eval_us = 0;
|
|
int64_t t_eval_us = 0;
|
|
|
|
int64_t t_compute_start_us = 0;
|
|
int64_t n_queued_tokens = 0;
|
|
|
|
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
|
int32_t n_eval = 0; // number of eval calls
|
|
|
|
// host buffer for the model output (logits and embeddings)
|
|
ggml_backend_buffer_t buf_output = nullptr;
|
|
|
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
|
size_t logits_size = 0; // capacity (of floats) for logits
|
|
float * logits = nullptr;
|
|
|
|
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
|
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
|
|
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
|
|
|
bool logits_all = false;
|
|
|
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
|
size_t embd_size = 0; // capacity (of floats) for embeddings
|
|
float * embd = nullptr;
|
|
|
|
// sequence embeddings output (map of [n_embd] vectors)
|
|
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
|
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
|
|
|
// whether we are computing encoder output or decoder output
|
|
bool is_encoding = false;
|
|
|
|
// output of the encoder part of the encoder-decoder models
|
|
std::vector<float> embd_enc;
|
|
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
|
|
|
// memory buffers used to evaluate the model
|
|
std::vector<uint8_t> buf_compute_meta;
|
|
ggml_backend_sched_t sched = nullptr;
|
|
|
|
ggml_abort_callback abort_callback = nullptr;
|
|
void * abort_callback_data = nullptr;
|
|
|
|
// input tensors
|
|
struct ggml_tensor * inp_tokens; // I32 [n_batch]
|
|
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
|
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
|
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
|
|
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
|
|
struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
|
|
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
|
|
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
|
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
|
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
|
|
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
|
|
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
|
|
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
|
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
|
struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
|
|
struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens]
|
|
|
|
ggml_backend_t ggml_backend_by_name(const char * name);
|
|
|
|
struct Prev;
|
|
std::unique_ptr<Prev> prev;
|
|
|
|
void reset_scheduler();
|
|
bool can_reuse_graph(const llama_batch & u_batch);
|
|
|
|
struct CacheCopy {
|
|
ggml_tensor * cpy = nullptr;
|
|
size_t step = 0;
|
|
};
|
|
std::vector<CacheCopy> cache_copies;
|
|
|
|
bool update_cache_copies();
|
|
|
|
};
|