Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-27 08:34:09 +00:00)

Deepseek MLA Optimizations

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Committed by: Iwan Kawrakow
Parent: 6d7b58eade
Commit: 35246c4e75
convert_hf_to_gguf.py

@@ -3123,6 +3123,7 @@ class ArcticModel(Model):
 
 
 @Model.register("DeepseekV2ForCausalLM")
+@Model.register("DeepseekV3ForCausalLM")
 class DeepseekV2Model(Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -3144,6 +3145,15 @@ class DeepseekV2Model(Model):
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+        if hparams["scoring_func"] == "sigmoid":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        elif hparams["scoring_func"] == "softmax":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")
+
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
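Note on what this metadata controls downstream: DeepSeek V2 selects experts from softmax scores, while DeepSeek V3 uses sigmoid scores together with the e_score_correction bias (renamed in the next hunk), where the bias influences only which experts are selected, not their gate weights. A minimal PyTorch sketch of the two variants — names, shapes, and the route() helper are illustrative, not llama.cpp code:

    import torch

    def route(logits: torch.Tensor, bias: torch.Tensor, top_k: int, scoring_func: str):
        if scoring_func == "softmax":
            # DeepSeek V2 style: softmax scores, plain top-k
            scores = torch.softmax(logits, dim=-1)
            weights, idx = torch.topk(scores, top_k, dim=-1)
        elif scoring_func == "sigmoid":
            # DeepSeek V3 style: sigmoid scores; the correction bias shifts
            # expert *selection* only, the gate weights use the raw scores
            scores = torch.sigmoid(logits)
            _, idx = torch.topk(scores + bias, top_k, dim=-1)
            weights = torch.gather(scores, -1, idx)
        else:
            raise ValueError(f"Unsupported scoring_func value: {scoring_func}")
        # with norm_topk_prob, the selected weights are renormalized to sum to 1
        # (routed_scaling_factor is applied afterwards)
        return idx, weights / weights.sum(dim=-1, keepdim=True)

    idx, w = route(torch.randn(2, 64), torch.zeros(64), top_k=8, scoring_func="sigmoid")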
@@ -3156,6 +3166,17 @@ class DeepseekV2Model(Model):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # rename e_score_correction_bias tensors
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        # skip Multi-Token Prediction (MTP) layers
+        block_count = self.hparams["num_hidden_layers"]
+        match = re.match(r"model.layers.(\d+)", name)
+        if match and int(match.group(1)) >= block_count:
+            return []
+
+
         # process the experts separately
         if name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
@@ -3188,6 +3209,27 @@ class DeepseekV2Model(Model):
                 return tensors
             else:
                 return []
 
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name), data_torch),
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
+
         return [(self.map_tensor_name(name), data_torch)]
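Why the split and the transpose: with MLA, k_nope = W_UK · c^KV per head, so the dot product q_nopeᵀ · k_nope can instead be computed as (W_UKᵀ · q_nope)ᵀ · c^KV, entirely in the kv_lora_rank-dimensional latent space. attn_k_b therefore stores the per-head W_UKᵀ, while attn_v_b (W_UV) is kept untransposed because it is applied after the attention-weighted sum of latents. A small PyTorch sanity check of that identity, with toy dimensions and random values (not part of the converter):

    import torch

    n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 16, 12, 8

    # stand-in for kv_b_proj.weight: (n_head * (qk_nope + v), kv_lora_rank)
    data_torch = torch.randn(n_head_kv * (v_head_dim + qk_nope_head_dim), kv_lora_rank)

    # the same split performed above
    kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, kv_lora_rank)
    k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
    k_b_t = k_b.transpose(1, 2)              # (n_head, kv_lora_rank, qk_nope)

    c = torch.randn(kv_lora_rank)            # compressed c^KV for one cached token
    q = torch.randn(qk_nope_head_dim)        # q_nope for one head
    W_uk = k_b[0]                            # head 0: maps c^KV -> k_nope

    score_direct   = q @ (W_uk @ c)          # decompress k_nope, then dot with q
    score_absorbed = (k_b_t[0] @ q) @ c      # map q into the latent space instead
    assert torch.allclose(score_direct, score_absorbed, atol=1e-5)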
gguf-py/gguf/constants.py

@@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B = auto()
     ATTN_KV_A_MQA = auto()
     ATTN_KV_B = auto()
+    ATTN_K_B = auto()
+    ATTN_V_B = auto()
     ATTN_Q_A_NORM = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM = auto()

@@ -403,6 +405,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",

@@ -967,6 +971,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_TENSOR.ATTN_Q_B,
     MODEL_TENSOR.ATTN_KV_A_MQA,
     MODEL_TENSOR.ATTN_KV_B,
+    MODEL_TENSOR.ATTN_K_B,
+    MODEL_TENSOR.ATTN_V_B,
     MODEL_TENSOR.ATTN_Q_A_NORM,
     MODEL_TENSOR.ATTN_KV_A_NORM,
     MODEL_TENSOR.ATTN_OUT,
src/llama.cpp (171 changed lines)
@@ -539,6 +539,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
@@ -1203,6 +1205,8 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_ATTN_Q_B,      "blk.%d.attn_q_b" },
     { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
     { LLM_TENSOR_ATTN_KV_B,     "blk.%d.attn_kv_b" },
+    { LLM_TENSOR_ATTN_K_B,      "blk.%d.attn_k_b" },
+    { LLM_TENSOR_ATTN_V_B,      "blk.%d.attn_v_b" },
     { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
     { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
     { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
@@ -2541,6 +2545,8 @@ struct llama_layer {
     struct ggml_tensor * wq_b;
     struct ggml_tensor * wkv_a_mqa;
     struct ggml_tensor * wkv_b;
+    struct ggml_tensor * wk_b;
+    struct ggml_tensor * wv_b;
     struct ggml_tensor * wq_cross;
     struct ggml_tensor * wk_cross;
     struct ggml_tensor * wv_cross;
@@ -2669,11 +2675,19 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    ggml_type type_kr = GGML_TYPE_F16;
+    ggml_type type_kv = GGML_TYPE_F16;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<struct ggml_tensor *> k_l; // per layer
     std::vector<struct ggml_tensor *> v_l;
 
+    // DeepSeek MLA
+    std::vector<struct ggml_tensor *> kr_l; // per layer
+    std::vector<struct ggml_tensor *> kv_l;
+    std::vector<struct ggml_tensor *> kvt_l;
+
     std::vector<struct ggml_context *> ctxs;
     std::vector<ggml_backend_buffer_t> bufs;
 
@@ -3132,7 +3146,7 @@ static bool llama_kv_cache_init(
     for (auto & it : buft_layer_count) {
         int n_layers = it.second;
         struct ggml_init_params params = {
-            /*.mem_size   =*/ 2u*n_layers*ggml_tensor_overhead(),
+            /*.mem_size   =*/ 5u*n_layers*ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc   =*/ true,
         };
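Note: the tensor-overhead multiplier grows from 2 to 5 because, besides the existing k and v tensors, each layer now also allocates the kr, kv, and kvt cache tensors created below.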
@@ -3148,6 +3162,11 @@ static bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);
 
+    // DeepSeek MLA
+    cache.kr_l.reserve(n_layer);
+    cache.kv_l.reserve(n_layer);
+    cache.kvt_l.reserve(n_layer);
+
     for (int i = 0; i < (int) n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -3159,6 +3178,21 @@ static bool llama_kv_cache_init(
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
         cache.v_l.push_back(v);
+
+        // DeepSeek MLA
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
+
+        ggml_tensor * kr  = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+        ggml_tensor * kv  = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_format_name(kr,  "cache_kr_l%d",  i);
+        ggml_format_name(kv,  "cache_kv_l%d",  i);
+        ggml_format_name(kvt, "cache_kvt_l%d", i);
+        cache.kr_l.push_back(kr);
+        cache.kv_l.push_back(kv);
+        cache.kvt_l.push_back(kvt);
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
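Per token and per layer, the MLA cache therefore holds n_embd_head_qk_rope K^R values plus 2 × kv_lora_rank c^KV values (the latent and its transposed copy), independent of the number of heads. A rough sizing comparison in Python, using the published DeepSeek-V3 dimensions purely for illustration:

    # F16 cache, 2 bytes per value; DeepSeek-V3-like dimensions (illustrative)
    n_layer, n_head = 61, 128
    qk_rope, qk_nope, v_head_dim, kv_lora_rank = 64, 128, 128, 512

    mla  = n_layer * (qk_rope + 2 * kv_lora_rank) * 2               # K^R + c^KV + (c^KV)^T
    full = n_layer * n_head * (qk_nope + qk_rope + v_head_dim) * 2  # naive K and V

    print(f"MLA : {mla / 1024:7.1f} KiB/token")   # ~129.6 KiB
    print(f"full: {full / 1024:7.1f} KiB/token")  # ~4880 KiB, roughly 38x larger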
@@ -7644,6 +7678,8 @@ static bool llm_load_tensors(
 
     layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
     layer.wkv_b     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+    layer.wk_b      = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B,      "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
+    layer.wv_b      = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B,      "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
     layer.wo        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_head * (n_embd_head_v), n_embd});
 
     layer.ffn_norm  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
@@ -13396,31 +13432,31 @@ struct llm_build_context {
                 LLM_NORM_RMS, cb, il);
         cb(kv_compressed, "kv_compressed", il);
 
-        // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-        struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
-        cb(kv, "kv", il);
+        struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+        cb(kv_cache_view, "kv_cache_view", il);
 
-        // split into {n_head * n_embd_head_qk_nope, n_tokens}
-        struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                0);
-        cb(k_nope, "k_nope", il);
+        // note: storing c^KV in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
 
-        // and {n_head * n_embd_head_v, n_tokens}
-        struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
-                ggml_row_size(kv->type, (n_embd_head_qk_nope)));
-        cb(v_states, "v_states", il);
+        struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head));
+        cb(kv_cache_trans_view, "kv_cache_trans_view", il);
 
-        v_states = ggml_cont(ctx0, v_states);
-        cb(v_states, "v_states", il);
+        // note: storing transposed c^KV in the transposed KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view));
 
-        v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                0);
-        cb(v_states, "v_states", il);
+        struct ggml_tensor * kv_cache =
+            ggml_view_2d(ctx0, kv_self.kv_l[il],
+                kv_lora_rank, n_kv,
+                ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                0);
+        cb(kv_cache, "kv_cache", il);
+
+        struct ggml_tensor * kv_cache_trans =
+            ggml_view_2d(ctx0, kv_self.kvt_l[il],
+                n_kv, kv_lora_rank,
+                ggml_row_size(kv_self.kv_l[il]->type, kv_self.size),
+                0);
+        cb(kv_cache_trans, "kv_cache_trans", il);
 
         //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
         q_pe = ggml_rope_ext(
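The same c^KV is stored twice, row-major (kv_l) and transposed (kvt_l), because the two matmuls that consume it contract over different dimensions: the score computation contracts over kv_lora_rank, while the probability-weighted sum contracts over the cached-token dimension n_kv, and ggml_mul_mat contracts along the fastest-varying dimension (ne0) of both operands. Keeping the transposed copy up to date trades memory for not having to transpose and copy the whole cache on every graph evaluation. A PyTorch sketch of the two products, with invented names and random data:

    import torch
    n_kv, kv_lora_rank, n_tokens = 10, 8, 3

    c   = torch.randn(n_kv, kv_lora_rank)  # kv_cache: one latent row per position
    c_t = c.T.contiguous()                 # kv_cache_trans: same data, transposed

    q_lat  = torch.randn(n_tokens, kv_lora_rank)  # queries mapped into latent space
    scores = q_lat @ c.T                          # kq_nope: contract over kv_lora_rank
    probs  = torch.softmax(scores, dim=-1)

    out_rowmajor = probs @ c               # what the math says
    out_trans    = (c_t @ probs.T).T       # what the graph computes from kvt_l
    assert torch.allclose(out_rowmajor, out_trans, atol=1e-6)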
@@ -13439,15 +13475,74 @@ struct llm_build_context {
         );
         cb(k_pe, "k_pe", il);
 
-        struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
-        cb(q_states, "q_states", il);
+        struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+        cb(kr_cache_view, "kr_cache_view", il);
 
-        struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
-        cb(k_states, "k_states", il);
+        // note: storing RoPE-ed version of K^R in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
 
-        cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                model.layers[il].wo, NULL,
-                k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+        struct ggml_tensor * kr_cache =
+            ggml_view_2d(ctx0, kv_self.kr_l[il],
+                n_embd_head_qk_rope, n_kv,
+                ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                0);
+        cb(kr_cache, "kr_cache", il);
+
+        struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
+        cb(wk_b, "wk_b", il);
+
+        struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+        cb(q_nope_perm, "q_nope_perm", il);
+
+        struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
+        cb(q_nope2, "q_nope2", il);
+
+        struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+        cb(q_nope2_perm, "q_nope2_perm", il);
+
+        struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+        cb(kq_nope, "kq_nope", il);
+
+        struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
+        cb(q_pe_perm, "q_pe_perm", il);
+
+        struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
+        cb(kq_pe, "kq_pe", il);
+
+        struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
+        cb(kq, "kq", il);
+
+        kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
+        cb(kq, "kq_perm", il);
+
+        kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+        cb(kq, "kq_soft_max_ext", il);
+
+        struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+        cb(kq_perm, "kq_soft_max_ext_perm", il);
+
+        struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+        cb(kqv_compressed, "kqv_compressed", il);
+
+        kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+        cb(kqv_compressed, "kqv_compressed_perm", il);
+
+        struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
+        cb(wv_b, "wv_b", il);
+
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
+        cb(kqv, "kqv", il);
+
+        kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+        cb(kqv, "kqv_perm", il);
+
+        cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+        cb(cur, "kqv_2d", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+        cb(cur, "kqv_out", il);
     }
 
     if (il == n_layer - 1) {
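Putting the pieces together: the new path never materializes per-head K or V. q_nope is mapped through wk_b into the latent space, scores are taken against the cached latents (plus the shared RoPE part from kr_cache), and only the attention-weighted latent is expanded through wv_b. A single-head PyTorch check that this "absorbed" ordering matches naive MLA attention — random stand-in weights, RoPE part omitted, and a simplified scale (the real kq_scale also covers the RoPE dimensions and YaRN mscale):

    import torch
    n_kv, kv_lora_rank, d_nope, d_v = 10, 8, 16, 12

    W_uk = torch.randn(d_nope, kv_lora_rank)  # per-head K part of wkv_b
    W_uv = torch.randn(d_v, kv_lora_rank)     # per-head V part of wkv_b
    c    = torch.randn(n_kv, kv_lora_rank)    # cached c^KV latents
    q    = torch.randn(d_nope)                # q_nope for one token, one head
    scale = d_nope ** -0.5                    # simplified stand-in for kq_scale

    # naive: decompress K and V for every cached position, then attend
    K = c @ W_uk.T                            # (n_kv, d_nope)
    V = c @ W_uv.T                            # (n_kv, d_v)
    out_naive = torch.softmax(K @ q * scale, dim=-1) @ V

    # absorbed: stay in the latent space (wk_b stores W_uk^T, wv_b stores W_uv)
    q_lat = W_uk.T @ q                        # q_nope2
    probs = torch.softmax(c @ q_lat * scale, dim=-1)
    out_absorbed = (probs @ c) @ W_uv.T       # kqv_compressed, then wv_b
    assert torch.allclose(out_naive, out_absorbed, atol=1e-4)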
@@ -17853,6 +17948,24 @@ struct llama_context * llama_new_context_with_model(
                     ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
 
+        {
+            size_t memory_size_kr = 0;
+            size_t memory_size_kv = 0;
+
+            for (auto & kr : ctx->kv_self.kr_l) {
+                memory_size_kr += ggml_nbytes(kr);
+            }
+
+            for (auto & kv : ctx->kv_self.kv_l) {
+                memory_size_kv += ggml_nbytes(kv);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                    (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+        }
+
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs