diff --git a/common/common.cpp b/common/common.cpp
index 1e761b6d..0498bf9f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1272,6 +1272,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.validate_quants = true;
         return true;
     }
+    if (arg == "-mqkv" || arg == "--merge-qkv") {
+        params.merge_qkv = true;
+        return true;
+    }
     if (arg == "--numa") {
         CHECK_ARG
         std::string value(argv[i]);
@@ -1911,6 +1915,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
+    options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
                        "in conversation mode, this will be used as system prompt\n"
                        "(default: '%s')", params.prompt.c_str() });
@@ -2778,7 +2783,7 @@ void llama_lora_adapters_apply(struct llama_context * ctx, std::vector
-    if (wq->type == wk->type && wq->type == wv->type && hparams.f_attention_scale == 0.0f) {
+    if (ml.merge_qkv && wq->type == wk->type && wq->type == wv->type && hparams.f_attention_scale == 0.0f) {
         GGML_ASSERT(wq->ne[0] == n_embd && wq->ne[1] == n_head * n_embd_head_k);
         GGML_ASSERT(wk->ne[0] == n_embd && wk->ne[1] == n_embd_gqa);
         GGML_ASSERT(wv->ne[0] == n_embd && wv->ne[1] == n_embd_gqa);
@@ -2454,7 +2454,7 @@ bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias) {
         layer.wk = ml.create_tensor_as_view(ctx_split, layer.wqkv, wk_name.c_str(), { wk->ne[0], wk->ne[1] }, wq->ne[1]*wq->nb[1]);
         layer.wv = ml.create_tensor_as_view(ctx_split, layer.wqkv, wv_name.c_str(), { wv->ne[0], wv->ne[1] }, wq->ne[1]*wq->nb[1] + wk->ne[1]*wk->nb[1] );
         fused_qkv = true;
-        printf("Created fused qkv %s\n", layer.wqkv->name);
+        printf("Created merged qkv %s\n", layer.wqkv->name);
         if (bias) {
             auto bq_name = tn(LLM_TENSOR_ATTN_Q, "bias", i);
             auto bk_name = tn(LLM_TENSOR_ATTN_K, "bias", i);
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 4fd26f2d..3f1ded18 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -203,9 +203,10 @@ namespace GGUFMeta {
     };
 }
 
-llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
-                                       const llama_model_kv_override * param_overrides_p,
-                                       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
+llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors,
+                                       bool repack_tensors, bool use_thp, bool merge_qkv,
+                                       const llama_model_kv_override * param_overrides_p,
+                                       const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
@@ -495,6 +496,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
     this->check_tensors = check_tensors;
     this->repack_tensors = repack_tensors;
     this->use_thp = use_thp;
+    this->merge_qkv = merge_qkv;
 }
 
 llama_model_loader::~llama_model_loader() {
diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h
index de57b704..366dea41 100644
--- a/src/llama-model-loader.h
+++ b/src/llama-model-loader.h
@@ -44,6 +44,7 @@ struct llama_model_loader {
     bool check_tensors;
     bool repack_tensors = false;
     bool use_thp = false;
+    bool merge_qkv = false;
 
     llama_files files;
     llama_ftype ftype;
@@ -78,7 +79,7 @@ struct llama_model_loader {
     std::string arch_name;
     LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
 
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp,
+    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, bool merge_qkv,
                        const llama_model_kv_override * param_overrides_p,
                        const llama_model_tensor_buft_override * param_tensor_buft_overrides_p);
 
diff --git a/src/llama-quantize.cpp b/src/llama-quantize.cpp
index fa08ba41..14d61971 100644
--- a/src/llama-quantize.cpp
+++ b/src/llama-quantize.cpp
@@ -1007,7 +1007,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, /* use_thp */ false, kv_overrides, nullptr);
+    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false,
+                          /* use_thp */ false, /* merge_qkv */ false, kv_overrides, nullptr);
     ml.init_mappings(false); // no prefetching
 
     llama_model model;
diff --git a/src/llama.cpp b/src/llama.cpp
index 3fba6574..700b006d 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1896,7 +1896,7 @@ static bool llm_load_tensors(
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
     try {
         llama_model_loader ml(fname, params.use_mmap, params.check_tensors,
-                              params.repack_tensors, params.use_thp, params.kv_overrides, params.tensor_buft_overrides);
+                              params.repack_tensors, params.use_thp, params.merge_qkv, params.kv_overrides, params.tensor_buft_overrides);
 
         model.hparams.vocab_only = params.vocab_only;
 
@@ -3788,6 +3788,7 @@ struct llama_model_params llama_model_default_params() {
         /*.repack_tensors =*/ false,
         /*.use_thp =*/ false,
         /*.validate_quants =*/ false,
+        /*.merge_qkv =*/ false,
     };
 
 #ifdef GGML_USE_METAL
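
Note on the create_tensors_helper::merge_qkv hunk above: Q, K and V become 2D views into the single merged wqkv tensor, with K placed immediately after Q's rows and V after Q's and K's rows, using the shared row stride nb[1] to compute the byte offsets. The sketch below illustrates that offset arithmetic with plain ggml_view_2d; the loader itself goes through ml.create_tensor_as_view, and the names used here (split_merged_qkv, qkv_views, n_q_rows, n_kv_rows) are illustrative, not part of the patch.

// Illustrative sketch (not part of the patch): mirrors the view-offset arithmetic of
// create_tensors_helper::merge_qkv, assuming wqkv's rows are laid out as [Q rows | K rows | V rows]
// and Q, K, V share the same type, so all three use wqkv's row stride nb[1].
#include "ggml.h"

struct qkv_views {
    struct ggml_tensor * wq;
    struct ggml_tensor * wk;
    struct ggml_tensor * wv;
};

static struct qkv_views split_merged_qkv(struct ggml_context * ctx, struct ggml_tensor * wqkv,
                                         int64_t n_embd, int64_t n_q_rows, int64_t n_kv_rows) {
    const size_t row_stride = wqkv->nb[1]; // bytes per row, identical for Q, K, V
    struct qkv_views v;
    v.wq = ggml_view_2d(ctx, wqkv, n_embd, n_q_rows,  row_stride, 0);                                 // rows [0, n_q_rows)
    v.wk = ggml_view_2d(ctx, wqkv, n_embd, n_kv_rows, row_stride, n_q_rows*row_stride);               // right after Q
    v.wv = ggml_view_2d(ctx, wqkv, n_embd, n_kv_rows, row_stride, (n_q_rows + n_kv_rows)*row_stride); // after Q and K
    return v;
}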