From c12f73ba6153d162f36434cb48e36dd3649b7701 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 9 Feb 2025 19:48:44 +0200 Subject: [PATCH 001/132] Add optional MLA (#188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Deepseek MLA Optimizations Co-authored-by: Stanisław Szymczyk * Make MLA optional * Remove some unnecessary copies in the MLA attention * Deepseek MLA Optimizations V2 (#195) * Avoid allocating MHA KV cache when MLA is turned on * Added missing gguf-py file * Added final optimizations Co-authored-by: Stanisław Szymczyk * Make sure we do have wk_b and wv_b before enabling MLA --------- Co-authored-by: Stanisław Szymczyk Co-authored-by: Iwan Kawrakow * Use type_k and type_v to set the types of the MLA caches They were hard-coded at f16. On my Ryzen-7950X with native bf16 support I get a fairly significant PP performance boost with bf16 KV-cache: PP-4096 = 320 t/s up from 292 t/s with fp16 KV-cache. * Better gemm strategy when nth > nhead It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads (with or without MLA). Before this commit, when nth > nhead heads were processed sequentially with all nth threads participating in each matrix multiplication. Now we ind the gcd of nhead and nth and split threads into nth/gcd groups, each group processing nhead/gcd heads. --------- Co-authored-by: Saood Karim Co-authored-by: Stanisław Szymczyk Co-authored-by: Iwan Kawrakow --- common/common.cpp | 7 + common/common.h | 1 + convert_hf_to_gguf.py | 42 ++++ examples/llama-bench/llama-bench.cpp | 35 ++- ggml/src/ggml.c | 17 +- gguf-py/gguf/constants.py | 6 + gguf-py/gguf/tensor_mapping.py | 8 + include/llama.h | 1 + src/llama.cpp | 338 ++++++++++++++++++++++----- 9 files changed, 380 insertions(+), 75 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 95e91bc1..6219f0ce 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } + if (arg == "-mla" || arg == "--mla-use") { + params.mla_attn = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 73d7d650..fc1ae619 100644 --- a/common/common.h +++ b/common/common.h @@ -174,6 +174,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool mla_attn = false; // MLA bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3910aa1d..1ee82724 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3123,6 +3123,7 @@ class ArcticModel(Model): @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3144,6 +3145,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3156,6 +3166,17 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] @@ -3188,6 +3209,27 @@ class DeepseekV2Model(Model): return tensors else: return [] + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim) + v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1]) + + return [ + (self.map_tensor_name(name), data_torch), + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] return [(self.map_tensor_name(name), data_torch)] diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 42320da8..41b93df5 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -232,6 +232,7 @@ struct cmd_params { std::vector main_gpu; std::vector no_kv_offload; std::vector flash_attn; + std::vector mla_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, + /* mla_attn */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); + } else if (arg == "-mla" || arg == "--mla-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } + if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -656,6 +667,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; + bool mla_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -698,6 +710,7 @@ struct cmd_params_instance { cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; + cparams.mla_attn = mla_attn; cparams.embeddings = embeddings; return cparams; @@ -722,6 +735,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) + for (const auto & mla : params.mla_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -743,6 +757,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -771,6 +786,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -799,6 +815,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -827,6 +844,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, + /* .mla_attn = */ mla, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -866,6 +884,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; + bool mla_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -895,6 +914,7 @@ struct test { main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; + mla_attn = inst.mla_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -988,7 +1008,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "tensor_split", "use_mmap", "embeddings", "repack", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1010,7 +1030,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { + field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1044,7 +1064,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1208,6 +1228,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return 2; } + if (field == "mla_attn") { + return 3; + } if (field == "use_mmap") { return 4; } @@ -1242,6 +1265,9 @@ struct markdown_printer : public printer { if (field == "flash_attn") { return "fa"; } + if (field == "mla_attn") { + return "mla"; + } if (field == "use_mmap") { return "mmap"; } @@ -1294,6 +1320,9 @@ struct markdown_printer : public printer { if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { fields.emplace_back("flash_attn"); } + if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) { + fields.emplace_back("mla_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e07dd547..3867cf00 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14064,31 +14064,22 @@ static void ggml_compute_forward_mul_mat( #endif #if GGML_USE_IQK_MULMAT - if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) { + if (dst->type == GGML_TYPE_F32) { + int gcd = simple_gcd(ne12*ne13, nth); int counter = 0; for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - if (counter++ % nth == ith) { + if ((counter++ % gcd) == (ith%gcd)) { if (!iqk_mul_mat(ne01, ne11, ne00, src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - 0, 1)) goto IQK_MulMat_Not_Available1; + ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1; } } } return; } - if (dst->type == GGML_TYPE_F32) { - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available1; - return; - } IQK_MulMat_Not_Available1:; #endif diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 90d5efec..037837da 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -403,6 +405,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -967,6 +971,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a70b69c5..e8725426 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -446,6 +446,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), diff --git a/include/llama.h b/include/llama.h index 730c087a..39251d35 100644 --- a/include/llama.h +++ b/include/llama.h @@ -374,6 +374,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 29926a94..00e6c934 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -539,6 +539,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, @@ -1203,6 +1205,8 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -2503,6 +2507,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool mla_attn; enum llama_pooling_type pooling_type; @@ -2541,6 +2546,8 @@ struct llama_layer { struct ggml_tensor * wq_b; struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; + struct ggml_tensor * wk_b; + struct ggml_tensor * wv_b; struct ggml_tensor * wq_cross; struct ggml_tensor * wk_cross; struct ggml_tensor * wv_cross; @@ -2669,11 +2676,19 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; + ggml_type type_kr = GGML_TYPE_F16; + ggml_type type_kv = GGML_TYPE_F16; + std::vector cells; std::vector k_l; // per layer std::vector v_l; + // DeepSeek MLA + std::vector kr_l; // per layer + std::vector kv_l; + std::vector kvt_l; + std::vector ctxs; std::vector bufs; @@ -3104,8 +3119,10 @@ static bool llama_kv_cache_init( cache.size = kv_size; cache.used = 0; - cache.type_k = type_k; - cache.type_v = type_v; + cache.type_k = type_k; + cache.type_v = type_v; + cache.type_kr = type_k; + cache.type_kv = type_v; cache.cells.clear(); cache.cells.resize(kv_size); @@ -3132,7 +3149,7 @@ static bool llama_kv_cache_init( for (auto & it : buft_layer_count) { int n_layers = it.second; struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ 5u*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -3148,17 +3165,46 @@ static bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); + // DeepSeek MLA + cache.kr_l.reserve(n_layer); + cache.kv_l.reserve(n_layer); + cache.kvt_l.reserve(n_layer); + for (int i = 0; i < (int) n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k; + ggml_tensor * v; + if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { + k = ggml_new_tensor_1d(ctx, type_k, 1); + v = ggml_new_tensor_1d(ctx, type_v, 1); + } + else { + k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + } + ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); + + + // DeepSeek MLA + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); + ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); + ggml_format_name(kr, "cache_kr_l%d", i); + ggml_format_name(kv, "cache_kv_l%d", i); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kr_l.push_back(kr); + cache.kv_l.push_back(kv); + cache.kvt_l.push_back(kvt); } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -7644,6 +7690,8 @@ static bool llm_load_tensors( layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}); + layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); + layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); @@ -8815,6 +8863,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; + const bool mla_attn; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8864,6 +8913,7 @@ struct llm_build_context { kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), + mla_attn (cparams.mla_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -13329,6 +13379,10 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + // whether to use n_tokens as the matrix dimension during multiplication or n_head + // n_tokens is higher during prompt processing, this allows to optimize for this case + bool pp_opt = n_tokens > n_head; + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -13383,71 +13437,216 @@ struct llm_build_context { 0); cb(kv_compressed, "kv_compressed", il); - // and {n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); + if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) { - //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm - kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(kv_compressed, "kv_compressed", il); + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); + //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(kv_compressed, "kv_compressed", il); - // split into {n_head * n_embd_head_qk_nope, n_tokens} - struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); + struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head); + cb(kv_cache_view, "kv_cache_view", il); - // and {n_head * n_embd_head_v, n_tokens} - struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); + // note: storing c^KV in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); + struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); - //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); + struct ggml_tensor * kv_cache = + ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank), + 0); + cb(kv_cache, "kv_cache", il); - // shared RoPE key - //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); + struct ggml_tensor * kv_cache_trans = + ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + // shared RoPE key + //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, NULL, - k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head); + cb(kr_cache_view, "kr_cache_view", il); + + // note: storing RoPE-ed version of K^R in the KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); + + struct ggml_tensor * kr_cache = + ggml_view_2d(ctx0, kv_self.kr_l[il], + n_embd_head_qk_rope, n_kv, + ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope), + 0); + cb(kr_cache, "kr_cache", il); + + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); + cb(wk_b, "wk_b", il); + + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); + cb(q_nope2, "q_nope2", il); + + if (!pp_opt) { + q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); + cb(q_nope2, "q_nope2_perm", il); + } + struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2); + cb(kq_nope, "kq_nope", il); + + if (!pp_opt) { + kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3); + cb(kq_nope, "kq_nope_perm", il); + } + + if (pp_opt) { + q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); + cb(q_pe, "q_pe_perm", il); + } + struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); + cb(kq_pe, "kq_pe", il); + + if (!pp_opt) { + kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3); + cb(kq_pe, "kq_pe_perm", il); + } + + struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + cb(wv_b, "wv_b", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + cb(kqv, "kqv", il); + + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cb(kqv, "kqv_perm", il); + + cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); + cb(cur, "kqv_2d", il); + + ggml_build_forward_expand(gf, cur); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "kqv_out", il); + + } + else { + + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + + } } if (il == n_layer - 1) { @@ -17391,6 +17590,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, + /*.mla_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17589,6 +17789,7 @@ struct llama_context * llama_new_context_with_model( cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; + cparams.mla_attn = params.mla_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17655,6 +17856,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -17853,6 +18055,24 @@ struct llama_context * llama_new_context_with_model( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + { + size_t memory_size_kr = 0; + size_t memory_size_kv = 0; + + for (auto & kr : ctx->kv_self.kr_l) { + memory_size_kr += ggml_nbytes(kr); + } + + for (auto & kv : ctx->kv_self.kv_l) { + memory_size_kv += ggml_nbytes(kv); + } + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f)); + } + // graph outputs buffer { // resized during inference when a batch uses more outputs From a366a3d17d8f2de0eb8c3d9eddc7b5840fb5761a Mon Sep 17 00:00:00 2001 From: saood06 Date: Mon, 10 Feb 2025 09:40:38 -0600 Subject: [PATCH 002/132] Load all MoE experts during warmup and make warmup 1 token (#198) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Load all MoE experts during warmup Co-authored-by: Stanisław Szymczyk * Unify warmup to one token --------- Co-authored-by: Stanisław Szymczyk --- common/common.cpp | 6 ++++-- examples/llama-bench/llama-bench.cpp | 2 +- src/llama.cpp | 19 ++++++++++++------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6219f0ce..44678d7a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2169,8 +2169,10 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { if (bos != -1) { tmp.push_back(bos); } - tmp.push_back(eos); - + else + { + tmp.push_back(eos); + } if (llama_model_has_encoder(model)) { llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 41b93df5..95df06dc 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1586,7 +1586,7 @@ int main(int argc, char ** argv) { if (params.warmup) { if (t.n_prompt > 0) { //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, 1, 0, t.n_batch, t.n_threads); } if (t.n_gen > 0) { test_gen(ctx, 1, 0, t.n_threads); diff --git a/src/llama.cpp b/src/llama.cpp index 00e6c934..b2553802 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3784,7 +3784,7 @@ static size_t llama_model_max_nodes(const llama_model & /*model*/) { // return 32768; //} - return 8192; + return 65536; } struct llama_model_loader { @@ -8879,7 +8879,8 @@ struct llm_build_context { llama_context & lctx, const llama_batch & batch, const llm_build_cb & cb, - bool worst_case) : + bool worst_case, + bool warmup) : model (lctx.model), lctx (lctx), hparams (model.hparams), @@ -8897,7 +8898,7 @@ struct llm_build_context { n_embd_head_v (hparams.n_embd_head_v), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), - n_expert_used (hparams.n_expert_used), + n_expert_used (warmup ? hparams.n_expert : hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor), @@ -14433,7 +14434,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14450,7 +14451,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14467,7 +14468,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - struct llm_build_context llm(lctx, dummy, cb, false); + struct llm_build_context llm(lctx, dummy, cb, false, false); llm.init(); @@ -14517,7 +14518,11 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_cgraph * result = NULL; - struct llm_build_context llm(lctx, batch, cb, worst_case); + const llama_vocab * vocab = llama_get_vocab(&lctx); + llama_token bos = llama_token_bos_impl(*vocab); + llama_token eos = llama_token_eos_impl(*vocab); + bool is_warming_up = (batch.n_tokens == 1 && batch.token[0] == bos); + struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up); llm.init(); From 3c98bfb33d149a0d9d3bb91604dd12709721e3cf Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 11 Feb 2025 14:46:30 +0200 Subject: [PATCH 003/132] DeepSeek FA support (CPU only) (#200) * Adding support for K head size != V head size This is relevant for DeepSeek models. At this point ggml CPU FA works. Now I need to go and change iqk FA to make it work with Dk != Dv. * iqk support for K head size != V head size To not have compilation time explode, just Dk = 192, Dv = 128 for now (DeepSeek) * FA: very slightly faster for nq = 1 (TG) --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 61 ++++---- ggml/src/iqk/iqk_mul_mat.cpp | 278 +++++++++++++++++++++++------------ ggml/src/iqk/iqk_mul_mat.h | 3 +- src/llama.cpp | 8 +- 4 files changed, 221 insertions(+), 129 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3867cf00..7b631177 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -8473,8 +8473,12 @@ struct ggml_tensor * ggml_flash_attn_ext( is_node = true; } + // k*q will be { k->ne[1], q->ne[2], q->ne[1], q->ne[3] } + // v^T is { v->ne[1], v->ne[0], v->ne[2], v->ne[3] } + // => result is { v->ne[0], q->ne[2], q->ne[1], q->ne[3] } // permute(0, 2, 1, 3) - int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + //int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); float params[] = { scale, max_bias, softcap }; @@ -17436,10 +17440,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ith = params->ith; const int nth = params->nth; - const int64_t D = neq0; - const int64_t N = neq1; + const int64_t Dk = nek0; + const int64_t Dv = nev0; + const int64_t N = neq1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == Dv); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -17447,12 +17452,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == Dk); + GGML_ASSERT(nek0 == Dk); + GGML_ASSERT(nev0 == Dv); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(nev0 == Dv); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -17516,7 +17521,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( int iq1 = (ith%ntg)*neq1g; int this_neq1 = MIN(neq1g, neq1-iq1); if (!iqk_flash_attn_noalibi(k->type, v->type, - D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), @@ -17543,6 +17548,8 @@ IQK_Flash_Attn_NotAvailable:; ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + const int64_t Dkv = MAX(Dk, Dv); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices @@ -17556,15 +17563,15 @@ IQK_Flash_Attn_NotAvailable:; float S = 0.0f; // sum float M = -INFINITY; // maximum KQ value - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + float * VKQ32 = (float *) params->wdata + ith*(3*Dkv + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*Dkv); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*Dkv); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*Dkv); // (temporary) buffer for Q converted to quantized/FP16 if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + memset(VKQ16, 0, Dkv*sizeof(ggml_fp16_t)); } else { - memset(VKQ32, 0, D*sizeof(float)); + memset(VKQ32, 0, Dkv*sizeof(float)); } const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -17578,7 +17585,7 @@ IQK_Flash_Attn_NotAvailable:; const int iv2 = iq2 / rv2; const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); + q_to_vec_dot(pq, Q_q, Dk); // online softmax / attention // loop over n_kv and n_head_kv @@ -17592,7 +17599,7 @@ IQK_Flash_Attn_NotAvailable:; float s; // KQ value const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + kq_vec_dot(Dk, &s, 0, k_data, 0, Q_q, 0, 1); s = softcap == 0.0f ? s*scale + mv : softcap*tanhf(s*scale) + mv; // scale KQ value and apply mask @@ -17610,14 +17617,14 @@ IQK_Flash_Attn_NotAvailable:; ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); + ggml_vec_scale_f16(Dv, VKQ16, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + ggml_vec_mad_f16(Dv, VKQ16, (const ggml_fp16_t *) v_data, vs); } else { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f @@ -17625,30 +17632,30 @@ IQK_Flash_Attn_NotAvailable:; ms = expf(Mold - M); // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); + ggml_vec_scale_f32(Dv, VKQ32, ms); } else { // no new maximum, ms == 1.0f, vs != 1.0f vs = expf(s - M); } - v_to_float(v_data, V32, D); + v_to_float(v_data, V32, Dv); // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); + ggml_vec_mad_f32(Dv, VKQ32, V32, vs); } S = S*ms + vs; // scale and increment sum with partial sum } if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { + for (int64_t d = 0; d < Dv; ++d) { VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); } } // V /= S const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); + ggml_vec_scale_f32(Dv, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -21112,9 +21119,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne00 = node->src[0]->ne[0]; // D + const int64_t Dk = node->src[0]->ne[0]; + const int64_t Dv = node->src[2]->ne[0]; + const int64_t D = MAX(Dk, Dv); - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ee0af7e9..3b58495e 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -14879,10 +14879,60 @@ struct FlashQKV { using qkv_cache_t = float; #endif + template + inline void accumulate_qkv_1(const VHelper& vh, const FlashMS& fms) { + F16::Data vq[D/F16::block_size]; + if (fms.need_scaling[0] == 2) { + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::zero(); + } else { + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::load(qkv_cache + F16::block_size*i); + if (fms.need_scaling[0] == 1) { + auto vms = F16::set1(fms.vms[0]); + for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]); + } + } + //F16::Data v[8]; + F16::Data v0, v1; + for (int l = 0; l < k_step; l += 4) { + auto vs0 = F16::set1(fms.cache[l + 0]); + auto vs1 = F16::set1(fms.cache[l + 1]); + auto vs2 = F16::set1(fms.cache[l + 2]); + auto vs3 = F16::set1(fms.cache[l + 3]); + //auto vs = F16::set4(fms.cache + l); + for (int i = 0; i < D/F16::block_size; i += 2) { + vh.load(l+0, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs0); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs0); + vh.load(l+1, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs1); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs1); + vh.load(l+2, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs2); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs2); + vh.load(l+3, i, v0, v1); + vq[i+0] = F16::fmadd(vq[i+0], v0, vs3); + vq[i+1] = F16::fmadd(vq[i+1], v1, vs3); + //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs); + //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs); + //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs); + //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs); + //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs); + //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs); + //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs); + //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs); + } + } + for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]); + } + // This fails for head sizes of 80 and 112 as D/16 is odd, so we cannot do steps of 2 // Hence, for now, we will not handle head sizes of 80 and 112 template inline void accumulate_qkv(const VHelper& vh, const FlashMS& fms) { + if constexpr (q_step == 1) { + accumulate_qkv_1(vh, fms); + return; + } F16::Data v[8]; for (int j = 0; j < q_step; ++j) { auto R = qkv_cache + D*j; @@ -14924,6 +14974,10 @@ struct FlashQKV { template inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS& fms) { + if (nq1 == 1) { + accumulate_qkv_1(vh, fms); + return; + } F16::Data v[8]; for (int j = 0; j < nq1; ++j) { auto R = qkv_cache + D*j; @@ -15346,13 +15400,13 @@ struct FlashQKfp32 { } }; -template +template void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, FlashMS& fms, - FlashQKV& fqkv, + FlashQKV& fqkv, const float * q, const char * mask, float * qkv) { #ifdef __aarch64__ - float16_t q_f16[D*q_step]; + float16_t q_f16[Dk*q_step]; #endif for (int i1 = 0; i1 < nq1/q_step; ++i1) { @@ -15365,7 +15419,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(kh, Dk, stride_m, q_f16, mr, fms); #else KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); #endif @@ -15391,7 +15445,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { #ifdef __aarch64__ - KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms); + KQHelper::multiply_mask_kq(n_left, kh, Dk, stride_m, q_f16, mr, fms); #else KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); #endif @@ -15404,12 +15458,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in } } -template +template void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, FlashMS& fms, - FlashQKV& fqkv, + FlashQKV& fqkv, const float * q, const char * mask, float * qkv) { - typename KHelper::block_q8 q8[q_step*(D/QK8_0)]; + typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)]; #if FA_TIMING Perf perf(false); #endif @@ -15420,7 +15474,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80::convert(q_step, stride_q, q, q8); + HelperQ80::convert(q_step, stride_q, q, q8); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -15458,7 +15512,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, fms.init_qstep(); kh.reset_block(); vh.reset_block(); - HelperQ80::convert(n_left, stride_q, q, q8); + HelperQ80::convert(n_left, stride_q, q, q8); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { KQHelper::mul_mask_kq(n_left, kh, stride_m, q8, mr, fms); @@ -15484,9 +15538,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, // rows (if Nq is not a multiple of q_step). One could have made the number of q^T rows to // process template parameter of such functions, but this would result in the compiler generating // q_step-1 versions of these functions for us, which I though was too much with q_step = 8. -template +template struct FlashAttn { - static_assert(D%F16::block_size == 0 && D <= 256); + static_assert(Dk%F16::block_size == 0 && Dk <= 256); + static_assert(Dv%F16::block_size == 0 && Dv <= 256); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -15495,35 +15550,35 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - if constexpr (std::is_same_v> || std::is_same_v> || - std::is_same_v> || - std::is_same_v>) { - compute_helper_q>( + if constexpr (std::is_same_v> || std::is_same_v> || + std::is_same_v> || + std::is_same_v>) { + compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { if (nq1 >= 8) { #if FA_TIMING auto t1 = Perf::cur_time(); - HelperQ80R4 khr4(nk1, kh); + HelperQ80R4 khr4(nk1, kh); Perf::instance().accum(4, t1); #else - HelperQ80R4 khr4(nk1, kh); + HelperQ80R4 khr4(nk1, kh); #endif - compute_helper_q, VHelper, FlashQKfp32>( + compute_helper_q, VHelper, FlashQKfp32>( khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else{ - compute_helper_q>( + compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } } else { - compute_helper>( + compute_helper>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } } - FlashMS fms; - FlashQKV fqkv; + FlashMS fms; + FlashQKV fqkv; }; @@ -15756,7 +15811,22 @@ struct FlashQKbf16 { static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, const char * mask, FlashMS& fms) { #endif - { + if constexpr (q_step == 1) { + __m512bh vq[D/32]; + __m512bh vk[D/32]; + __m256 sum[8]; + for (int i = 0; i < D/32; ++i) vq[i] = __m512bh(_mm512_loadu_si512((const __m512i *)q + i)); + for (int l = 0; l < k_step; l += 8) { + for (int k = 0; k < 8; ++k) { + kh.load(l+k, vk); + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vk[i], vq[i]); + sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1)); + } + _mm256_storeu_ps(fms.cache + l, hsum_float_8x8(sum)); + } + } + else { __m512bh qv[D/32]; if constexpr (D <= 128) { __m512bh vkh[D/4]; @@ -15856,9 +15926,10 @@ struct FlashQKbf16 { } }; -template +template struct FlashAttnBF16 { - static_assert(D%32 == 0 && D <= 256); + static_assert(Dk%32 == 0 && Dk <= 256); + static_assert(Dv%32 == 0 && Dv <= 256); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -15867,7 +15938,7 @@ struct FlashAttnBF16 { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - ggml_bf16_t q_bf16[q_step*D]; + ggml_bf16_t q_bf16[q_step*Dk]; #if FA_TIMING Perf perf(false); #endif @@ -15878,7 +15949,7 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - FlashQKbf16::convert(stride_q, q, q_bf16); + FlashQKbf16::convert(stride_q, q, q_bf16); #if FA_TIMING perf.accum_nolock(0, t1); #endif @@ -15886,13 +15957,13 @@ struct FlashAttnBF16 { for (int k1 = 0; k1 < nk1/k_step; ++k1) { #if FA_TIMING //t1 = Perf::cur_time(); - FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); //perf.accum_nolock(1, t1); t1 = Perf::cur_time(); fqkv.accumulate_qkv(vh, fms); perf.accum_nolock(3, t1); #else - FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(vh, fms); #endif kh.next_block(); @@ -15916,10 +15987,10 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); - FlashQKbf16::convert(n_left, stride_q, q, q_bf16); + FlashQKbf16::convert(n_left, stride_q, q, q_bf16); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - FlashQKbf16::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); + FlashQKbf16::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block(); @@ -15932,72 +16003,72 @@ struct FlashAttnBF16 { #endif } - FlashMS fms; - FlashQKV fqkv; + FlashMS fms; + FlashQKV fqkv; }; #endif -template +template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv) { if (nk1 >= 256) { //4096) { if (nq1 >= 64) { - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } if (nq1 >= 32) { - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } if (nq1 >= 16) { - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } } if (nq1 >= 8) { - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { - FlashAttn fa(scale, softcap); + FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } } #ifdef __AVX512BF16__ -template +template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, float scale, float softcap, float * qkv) { - HelperBF16 kh(k, stride_k); - HelperBF16 vh(v, stride_v); + HelperBF16 kh(k, stride_k); + HelperBF16 vh(v, stride_v); if (nk1 >= 4096) { if (nq1 >= 64) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } else if (nq1 >= 16) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); return; } } if (nq1 >= 8) { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { - FlashAttnBF16 fa(scale, softcap); + FlashAttnBF16 fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } } #endif -template +template inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, @@ -16005,42 +16076,42 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, switch (type_v) { case GGML_TYPE_F16: { - HelperF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperF16 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #ifdef HAVE_FANCY_SIMD case GGML_TYPE_BF16: { - HelperBF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperBF16 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #endif case GGML_TYPE_Q8_0: { - HelperQ80 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ80 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { - HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ60 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { - HelperQ40 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ40 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { - HelperQ41 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperQ41 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + HelperIQ4nl vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; #endif default: break; } } -template +template inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, @@ -16048,29 +16119,29 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, switch (type_k) { case GGML_TYPE_F16: { - HelperF16 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperF16 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q8_0: { - HelperQ80 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ80 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { - HelperQ60 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ60 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { - HelperQ40 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ40 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { - HelperQ41 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperQ41 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { - HelperIQ4nl kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + HelperIQ4nl kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; #endif default: break; @@ -16094,7 +16165,8 @@ inline bool flash_attn_is_supported(ggml_type type) { bool iqk_flash_attn_noalibi(int int_type_k, // type of k int int_type_v, // type of v - int D, // head size + int Dk, // K head size + int Dv, // V head size int nq1, // number of columns in q int nk1, // number of rows in k int stride_q, // distance between q columns in bytes @@ -16114,7 +16186,9 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k auto type_v = ggml_type(int_type_v); if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 - if (D != 64 && D != 96 && D != 128 && D != 256) return false; + if (Dk != Dv && Dk != 192 && Dv != 128) return false; + if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; + if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v; @@ -16126,30 +16200,34 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_k == GGML_TYPE_BF16) { if (nk1%64 == 0) { if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } return true; } if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } @@ -16159,41 +16237,45 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k #endif if (nk1%64 == 0) { - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } return true; } - switch (D) { + switch (Dk) { case 64: - iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 192: + iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 6e27c614..b24dc7b2 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -23,7 +23,8 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, bool iqk_flash_attn_noalibi(int type_k, // type of k int type_v, // type of v - int D, // head size + int Dk, // K head size + int Dv, // V head size int nq, // number of columns in q int nk, // number of rows in k int stride_q, // distance between q columns in bytes diff --git a/src/llama.cpp b/src/llama.cpp index b2553802..0817c53c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17768,10 +17768,10 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { - LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); - params.flash_attn = false; - } + //if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { + // LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); + // params.flash_attn = false; + //} if (params.type_v != GGML_TYPE_F16 && params.type_v != GGML_TYPE_BF16 && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); From e974fc9e6691825611c781525ae16d77eed3cbc0 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 12 Feb 2025 07:20:38 +0200 Subject: [PATCH 004/132] Fix imatrix overprotectiveness (#202) Co-authored-by: Iwan Kawrakow --- examples/imatrix/imatrix.cpp | 39 ++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 497c9d14..8006988c 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -39,6 +39,7 @@ struct Stats { std::vector values; std::vector counts; int ncall = 0; + int n_as = 1; }; class IMatrixCollector { @@ -132,11 +133,15 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (e.values.empty()) { e.values.resize(src1->ne[0]*n_as, 0); e.counts.resize(src1->ne[0]*n_as, 0); + e.n_as = n_as; } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); exit(1); //GGML_ABORT("fatal error"); } + else if (e.n_as != n_as) { + fprintf(stderr, "Oops: inconsistent n_as for %s (%d vs %d)\n", wname.c_str(), e.n_as, n_as); + } if (m_params.verbosity > 1) { printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); } @@ -258,8 +263,38 @@ void IMatrixCollector::save_imatrix(int ncall) const { } if (n_zeros > 0) { - fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); - continue; + fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%)", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); + bool store_it = false; + if (kv.second.n_as > 1) { + int n_per_expert = n_all / kv.second.n_as; + std::vector bad_experts; + bad_experts.reserve(kv.second.n_as); + for (int i = 0; i < kv.second.n_as; ++i) { + auto counts = kv.second.counts.data() + i*n_per_expert; + int nz_i = 0; + for (int j = 0; j < n_per_expert; ++j) { + if (counts[j] == 0) ++nz_i; + } + if (nz_i > 0) bad_experts.push_back(i); + } + fprintf(stderr, " %d out of %d experts are missing data", int(bad_experts.size()), kv.second.n_as); + if (bad_experts.size() < round(kv.second.n_as * 0.05)) { + fprintf(stderr, " Storing **but be aware**\n"); + store_it = true; + for (auto i : bad_experts) { + auto counts = (int *)kv.second.counts.data() + i*n_per_expert; + auto values = (float *)kv.second.values.data() + i*n_per_expert; + for (int j = 0; j < n_per_expert; ++j) { + counts[j] = 1; + values[j] = 1; + } + } + } + } + if (!store_it) { + fprintf(stderr, " - skipping\n"); + continue; + } } n_entries++; From 1bbb543478bbc0997c3f86581c4f95338a5eb5c3 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 12 Feb 2025 14:22:26 +0200 Subject: [PATCH 005/132] Fix iqk_mul_mat on AVX512 systems that are missing BF16 support (#204) * Fix iqk_mul_mat on AVX512 systems that are missing BF16 support * One more --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3b58495e..e39c27d9 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -8082,6 +8082,9 @@ struct QFBase { using Acc = __m512; static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); } static inline Data load(const float * x) { return _mm512_loadu_ps(x); } + static inline Data load(const ggml_bf16_t * x) { + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16)); + } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm512_fmadd_ps(y, x, prev); } @@ -16079,7 +16082,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperF16 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512BF16__ case GGML_TYPE_BF16: { HelperBF16 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); From 05242ff17d3685321ea0ea12021f77609219f2a6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 13 Feb 2025 11:50:20 +0200 Subject: [PATCH 006/132] Faster MLA prompt processing (#205) * Do not allocate / report caches that are not used It is either the standard KV cache or MLA cache, not both. * Rename X_pe to X_rope Much easier to follow, at least for my brain, when we have X_rope : rotational position encoding X_nope : no position encoding instead of X_pe and X_nope, where I was wondering wtf is 'pe' and 'nope'. * WIP * WIP * WIP * WIP * Warn user when disabling MLA * MLA: compile time option to not use transposed KV cache Cuts KV cache size in nearly half at the expense of slower TG performance for long contexts (it becomes similar to no-MLA). --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 331 ++++++++++++++++++++++++-------------------------- 1 file changed, 159 insertions(+), 172 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 0817c53c..498bb437 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -108,6 +108,14 @@ #define LLAMA_MAX_LAYERS 512 #define LLAMA_MAX_EXPERTS 256 // DeepSeekV2 +// +// === MLA cache +// If tou are desperate to reduce KV cache size, set MLA_USE_TRANSPOSED_CACHE to 0. +// TG perfornce will be slower (similar to no-MLA), but KV cache size will be cut to ~half. +// PP performance will be about the same as with MLA_USE_TRANSPOSED_CACHE = 1. +// +#define MLA_USE_TRANSPOSED_CACHE 1 + // // helpers // @@ -2547,7 +2555,7 @@ struct llama_layer { struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; struct ggml_tensor * wk_b; - struct ggml_tensor * wv_b; + struct ggml_tensor * wv_b; struct ggml_tensor * wq_cross; struct ggml_tensor * wk_cross; struct ggml_tensor * wv_cross; @@ -2676,18 +2684,16 @@ struct llama_kv_cache { ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; - ggml_type type_kr = GGML_TYPE_F16; - ggml_type type_kv = GGML_TYPE_F16; - std::vector cells; std::vector k_l; // per layer std::vector v_l; // DeepSeek MLA - std::vector kr_l; // per layer std::vector kv_l; +#if MLA_USE_TRANSPOSED_CACHE std::vector kvt_l; +#endif std::vector ctxs; std::vector bufs; @@ -3121,8 +3127,6 @@ static bool llama_kv_cache_init( cache.type_k = type_k; cache.type_v = type_v; - cache.type_kr = type_k; - cache.type_kv = type_v; cache.cells.clear(); cache.cells.resize(kv_size); @@ -3166,10 +3170,13 @@ static bool llama_kv_cache_init( cache.v_l.reserve(n_layer); // DeepSeek MLA - cache.kr_l.reserve(n_layer); cache.kv_l.reserve(n_layer); +#if MLA_USE_TRANSPOSED_CACHE cache.kvt_l.reserve(n_layer); +#endif + bool warn = true; + int n_mla = 0; for (int i = 0; i < (int) n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -3177,34 +3184,53 @@ static bool llama_kv_cache_init( struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); ggml_tensor * k; ggml_tensor * v; + if (cparams.mla_attn) { + if (!model.layers[i].wk_b || !model.layers[i].wv_b) { + if (warn) { + LLAMA_LOG_WARN("=======================================================================================\n"); + LLAMA_LOG_WARN("%s: missing MLA tensors => disabling MLA\n", __func__); + LLAMA_LOG_WARN("%s: you need to reconvert your model in order to use MLA\n", __func__); + LLAMA_LOG_WARN("=======================================================================================\n"); + warn = false; + } + } + } if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { - k = ggml_new_tensor_1d(ctx, type_k, 1); - v = ggml_new_tensor_1d(ctx, type_v, 1); + // DeepSeek MLA + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); +#if MLA_USE_TRANSPOSED_CACHE + // TODO: The k-cache is contiguous and not permuted, so strictly speaking, it should be possible to quantize it. + // Sadly, at this point something goes wrong with quantized k-cache, so for now we set the k-cache + // type to type_v, which is guaranteed to be f16 or bf16 without FA. + //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); +#else + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); +#endif + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); +#if MLA_USE_TRANSPOSED_CACHE + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kvt_l.push_back(kvt); +#endif + n_mla++; } else { - k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - } - - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); - - - // DeepSeek MLA - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t kv_lora_rank = hparams.n_lora_kv; - LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); - ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size); - ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); - ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size); - ggml_format_name(kr, "cache_kr_l%d", i); - ggml_format_name(kv, "cache_kv_l%d", i); - ggml_format_name(kvt, "cache_kvt_l%d", i); - cache.kr_l.push_back(kr); - cache.kv_l.push_back(kv); - cache.kvt_l.push_back(kvt); + k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + } + } + if (cparams.mla_attn && n_mla < n_layer && n_mla > 0) { + LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer)); + LLAMA_LOG_ERROR("%s: bailing out\n", __func__); + GGML_ABORT("fatal error"); } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -13422,94 +13448,80 @@ struct llm_build_context { cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + struct ggml_tensor * q_rope = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, hparams.n_embd_head_k), ggml_row_size(q->type, hparams.n_embd_head_k * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); - cb(q_pe, "q_pe", il); + cb(q_rope, "q_rope", il); + + q_rope = ggml_rope_ext( + ctx0, q_rope, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_rope, "q_rope", il); // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + struct ggml_tensor * kv_rope_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_rope_compresseed, "kv_rope_compresseed", il); + + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_rope = ggml_view_3d(ctx0, kv_rope_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_rope_compresseed->nb[1], + kv_rope_compresseed->nb[1], + ggml_row_size(kv_rope_compresseed->type, kv_lora_rank)); + cb(k_rope, "k_rope", il); + + // shared RoPE key + k_rope = ggml_rope_ext( + ctx0, k_rope, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_rope, "k_rope", il); // split into {kv_lora_rank, n_tokens} - struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_rope_compresseed, kv_lora_rank, n_tokens, + kv_rope_compresseed->nb[1], 0); cb(kv_compressed, "kv_compressed", il); + kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(kv_compressed, "kv_compressed", il); + if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) { - // and {n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); - - //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm - kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(kv_compressed, "kv_compressed", il); - - struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head); - cb(kv_cache_view, "kv_cache_view", il); - - // note: storing c^KV in the KV cache - ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view)); - - struct ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head)); +#if MLA_USE_TRANSPOSED_CACHE + ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, + ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head)); cb(kv_cache_trans_view, "kv_cache_trans_view", il); // note: storing transposed c^KV in the transposed KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); - struct ggml_tensor * kv_cache = - ggml_view_2d(ctx0, kv_self.kv_l[il], - kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank), - 0); + ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); +#endif + + ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); + cb(kvr, "kvr", il); + + ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view)); + ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank + n_embd_head_qk_rope, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); cb(kv_cache, "kv_cache", il); - struct ggml_tensor * kv_cache_trans = - ggml_view_2d(ctx0, kv_self.kvt_l[il], - n_kv, kv_lora_rank, - ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), - 0); - cb(kv_cache_trans, "kv_cache_trans", il); - - //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); - - // shared RoPE key - //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); - - struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head); - cb(kr_cache_view, "kr_cache_view", il); - - // note: storing RoPE-ed version of K^R in the KV cache - ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view)); - - struct ggml_tensor * kr_cache = - ggml_view_2d(ctx0, kv_self.kr_l[il], - n_embd_head_qk_rope, n_kv, - ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope), - 0); - cb(kr_cache, "kr_cache", il); - - struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, + ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); cb(wk_b, "wk_b", il); q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); @@ -13518,33 +13530,20 @@ struct llm_build_context { struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); + ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); + cb(q, "q", il); if (!pp_opt) { - q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3); - cb(q_nope2, "q_nope2_perm", il); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); } - struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2); - cb(kq_nope, "kq_nope", il); - - if (!pp_opt) { - kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3); - cb(kq_nope, "kq_nope_perm", il); - } - - if (pp_opt) { - q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3); - cb(q_pe, "q_pe_perm", il); - } - struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe); - cb(kq_pe, "kq_pe", il); - - if (!pp_opt) { - kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3); - cb(kq_pe, "kq_pe_perm", il); - } - - struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe); + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); cb(kq, "kq", il); + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); @@ -13553,6 +13552,16 @@ struct llm_build_context { cb(kq, "kq_soft_max_ext_perm", il); } +#if !MLA_USE_TRANSPOSED_CACHE + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); + + ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); +#endif + struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); cb(kqv_compressed, "kqv_compressed", il); @@ -13561,7 +13570,9 @@ struct llm_build_context { cb(kqv_compressed, "kqv_compressed_perm", il); } - struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0); + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, + ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), + ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); cb(wv_b, "wv_b", il); struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); @@ -13581,19 +13592,6 @@ struct llm_build_context { } else { - // and {n_embd_head_qk_rope, n_tokens} - struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); - cb(k_pe, "k_pe", il); - - //kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm - kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(kv_compressed, "kv_compressed", il); - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); cb(kv, "kv", il); @@ -13620,27 +13618,10 @@ struct llm_build_context { 0); cb(v_states, "v_states", il); - //q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(q_pe, "q_pe", il); - - // shared RoPE key - //k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); - cb(k_pe, "k_pe", il); - - struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_rope, 0); cb(q_states, "q_states", il); - struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_rope, q_rope), 0); cb(k_states, "k_states", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, @@ -18054,28 +18035,34 @@ struct llama_context * llama_new_context_with_model( memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + if (memory_size_k + memory_size_v > 0) { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } } - { - size_t memory_size_kr = 0; + { size_t memory_size_kv = 0; - - for (auto & kr : ctx->kv_self.kr_l) { - memory_size_kr += ggml_nbytes(kr); - } + size_t memory_size_kvt = 0; for (auto & kv : ctx->kv_self.kv_l) { memory_size_kv += ggml_nbytes(kv); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__, - (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f)); +#if MLA_USE_TRANSPOSED_CACHE + for (auto & kvt : ctx->kv_self.kvt_l) { + memory_size_kvt += ggml_nbytes(kvt); + } +#endif + + if (memory_size_kv + memory_size_kvt > 0) { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_kvt / (1024.0f * 1024.0f)); + } } // graph outputs buffer From 8e94b29e35d69a6721c35767222d16dc65df2db6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 13 Feb 2025 14:44:33 +0200 Subject: [PATCH 007/132] MLA: allow Q8_0 K-cache for MLA (#206) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 498bb437..8c4a966d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3201,11 +3201,7 @@ static bool llama_kv_cache_init( const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); #if MLA_USE_TRANSPOSED_CACHE - // TODO: The k-cache is contiguous and not permuted, so strictly speaking, it should be possible to quantize it. - // Sadly, at this point something goes wrong with quantized k-cache, so for now we set the k-cache - // type to type_v, which is guaranteed to be f16 or bf16 without FA. - //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); - ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); + ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #else ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #endif @@ -13495,7 +13491,7 @@ struct llm_build_context { #if MLA_USE_TRANSPOSED_CACHE ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, - ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), ggml_row_size(kv_self.kv_l[il]->type, kv_head)); + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); cb(kv_cache_trans_view, "kv_cache_trans_view", il); // note: storing transposed c^KV in the transposed KV cache @@ -13503,7 +13499,7 @@ struct llm_build_context { ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_kv, kv_lora_rank, - ggml_row_size(kv_self.kv_l[il]->type, kv_self.size), + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), 0); cb(kv_cache_trans, "kv_cache_trans", il); #endif @@ -18047,21 +18043,26 @@ struct llama_context * llama_new_context_with_model( size_t memory_size_kv = 0; size_t memory_size_kvt = 0; + ggml_type kv_type = GGML_TYPE_COUNT; + ggml_type kvt_type = GGML_TYPE_COUNT; + for (auto & kv : ctx->kv_self.kv_l) { memory_size_kv += ggml_nbytes(kv); + kv_type = kv->type; } #if MLA_USE_TRANSPOSED_CACHE for (auto & kvt : ctx->kv_self.kvt_l) { memory_size_kvt += ggml_nbytes(kvt); + kvt_type = kvt->type; } #endif if (memory_size_kv + memory_size_kvt > 0) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_kv / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_kvt / (1024.0f * 1024.0f)); + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), + ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); } } From 0551e7630b2e74dfd140aadbf6a136c58199f474 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 15 Feb 2025 08:45:45 +0200 Subject: [PATCH 008/132] Moving 4D gemm logic from ggml.c to iqk_mul_mat.cpp (#207) This allows us to optimize TG performance for GQA models. E.g., for IQ4_XS L3-8B with 8k TG-64 goes from 8.6 to 10.26 t/s. Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 76 +++++++++++++++++------------------- ggml/src/iqk/iqk_mul_mat.cpp | 59 ++++++++++++++++++++++++++++ ggml/src/iqk/iqk_mul_mat.h | 7 ++++ 3 files changed, 101 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7b631177..045ea446 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14069,22 +14069,12 @@ static void ggml_compute_forward_mul_mat( #if GGML_USE_IQK_MULMAT if (dst->type == GGML_TYPE_F32) { - int gcd = simple_gcd(ne12*ne13, nth); - int counter = 0; - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - if ((counter++ % gcd) == (ith%gcd)) { - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1; - } - } - } - return; + if (iqk_mul_mat_4d(ne01, ne11, ne00, + ne02, ne03, ne12, ne13, nb02, nb03, nb12, nb13, nb2/sizeof(float), nb3/sizeof(float), + src0->type, src0->data, nb01, + src1->type, src1->data, nb11, + (float *)dst->data, nb1/sizeof(float), ith, nth)) return; } -IQK_MulMat_Not_Available1:; #endif #if GGML_USE_LLAMAFILE @@ -14125,6 +14115,29 @@ UseGgmlGemm1:; assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); +#ifdef GGML_USE_IQK_MULMAT + int ts = type_traits[vec_dot_type].type_size; + int bs = type_traits[vec_dot_type].blck_size; + int64_t blocks_per_row = ne10/bs; + int64_t num_blocks = ne11*ne12*ne13*blocks_per_row; + int gcd = simple_gcd(128, ts); // 128 is to cover cache line sizes for common architectures without getting involved + // with trying to get it from ggml + int64_t num_blocks_gcd = (num_blocks + gcd - 1)/gcd; + int64_t block_per_thread = ((num_blocks_gcd + nth - 1)/nth)*gcd; + int64_t first_block = ith*block_per_thread; + int64_t last_block = MIN(num_blocks, first_block + block_per_thread); + while (first_block < last_block) { + int64_t i13 = first_block/(ne11*ne12*blocks_per_row); + int64_t i12 = (first_block - i13*ne11*ne12*blocks_per_row)/(ne11*blocks_per_row); + int64_t i11 = (first_block - (i13*ne12 + i12)*ne11*blocks_per_row)/blocks_per_row; + int64_t i10 = first_block % blocks_per_row; + int64_t blocks_to_do = MIN(blocks_per_row - i10, last_block - first_block); + from_float((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11) + i10*bs, + (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + i10*ts), blocks_to_do*bs); + first_block += blocks_to_do; + } +#else + for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { int64_t i11_processed = 0; @@ -14145,6 +14158,7 @@ UseGgmlGemm1:; } } } +#endif ggml_barrier(params->shared); @@ -14165,34 +14179,14 @@ UseGgmlGemm1:; #if GGML_USE_IQK_MULMAT if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) { - // When K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing - // one matrix multiplication, but work on several heads at once. - // Hence, we find the GCD(n12*ne13, nth) and have nth/GCD(n12*ne13, nth) threads per head. - // Leaving the previous version commented out for now just in case. const size_t row_size = ggml_row_size(vec_dot_type, ne10); - int ntg = simple_gcd(ne12*ne13, nth); - int counter = 0; - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - if (counter++ % ntg == ith%ntg) { - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2; - } - } - } - //for (int64_t i13 = 0; i13 < ne13; i13++) - // for (int64_t i12 = 0; i12 < ne12; i12++) - // if (!iqk_mul_mat(ne01, ne11, ne00, - // src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - // vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), - // (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - // ith, nth)) goto IQK_MulMat_Not_Available2; - return; + if (iqk_mul_mat_4d(ne01, ne11, ne00, + ne02, ne03, ne12, ne13, nb02, nb03, row_size*ne11, row_size*ne11*ne12, + nb2/sizeof(float), nb3/sizeof(float), + src0->type, src0->data, nb01, + vec_dot_type, wdata, row_size, + (float *)dst->data, nb1/sizeof(float), ith, nth)) return; } -IQK_MulMat_Not_Available2:; #endif #if GGML_USE_LLAMAFILE diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index e39c27d9..8d6b45da 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -340,6 +340,56 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, return true; } +namespace { +inline uint32_t simple_gcd(uint32_t a, uint32_t b) { + while (a != b) { + if (a > b) a -= b; + else b -= a; + } + return a; +} +} + +bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, + long ne02, long ne03, long ne12, long ne13, + long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long stride_C, int ith, int nth) { + + auto r2 = ne12 / ne02; + auto r3 = ne13 / ne03; + + if (ne13 == 1 && Ny == 1 && r2 > 1) { + int gcd = simple_gcd(ne02, nth); + int counter = 0; + for (int64_t i12 = 0; i12 < ne02; i12++) { + if ((counter++ % gcd) == (ith%gcd)) { + if (!iqk_mul_mat(Nx, r2, ne00, + typeA, (const char *)A + i12*nb02, strideA, + typeB, (const char *)B + i12*r2*nb12, nb12, + C + r2*i12*nb2, nb2, + ith/gcd, nth/gcd)) return false; + } + } + return true; + } + int gcd = simple_gcd(ne12*ne13, nth); + int counter = 0; + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + if ((counter++ % gcd) == (ith%gcd)) { + if (!iqk_mul_mat(Nx, Ny, ne00, + typeA, (const char *)A + i12/r2*nb02 + i13/r3*nb03, strideA, + typeB, (const char *)B + i12*nb12 + i13*nb13, strideB, + C + i12*nb2 + i13*nb3, stride_C, + ith/gcd, nth/gcd)) return false; + } + } + } + return true; +} + bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, @@ -16292,6 +16342,15 @@ bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void return false; } +bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/, + long /*ne02*/, long /*ne03*/, long /*ne12*/, long /*ne13*/, + long /*nb02*/, long /*nb03*/, long /*nb12*/, long /*nb13*/, long /*nb2*/, long /*nb3*/, + int /*typeA*/, const void * /*A*/, long /*strideA*/, + int /*typeB*/, const void * /*B*/, long /*strideB*/, + float * /*C*/, long /*stride_C*/, int /*ith*/, int /*nth*/) { + return false; +} + bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, const void *, int, int) { return false; diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index b24dc7b2..d5f340b2 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -16,6 +16,13 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth); +bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, + long ne02, long ne03, long ne12, long ne13, + long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, + int typeA, const void * A, long strideA, + int typeB, const void * B, long strideB, + float * C, long stride_C, int ith, int nth); + bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, From d44aba79ea9bea07c22cbf2336b51a37ba823524 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 15 Feb 2025 19:50:53 +0200 Subject: [PATCH 009/132] Bug fix in activation quantization I added a change in the last PR how activations are quantized. It looked like it is working and slightly improving performance. But I now hit an edge case where I get gibberish that goes away if I remove the change. I absolutely don't see what goes wrong, so leaving the change in commented out for now. --- ggml/src/ggml.c | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 045ea446..8ab6b0a9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14115,28 +14115,28 @@ UseGgmlGemm1:; assert(params->wsize >= ne13*nbw3); GGML_ASSERT(src1->type == GGML_TYPE_F32); -#ifdef GGML_USE_IQK_MULMAT - int ts = type_traits[vec_dot_type].type_size; - int bs = type_traits[vec_dot_type].blck_size; - int64_t blocks_per_row = ne10/bs; - int64_t num_blocks = ne11*ne12*ne13*blocks_per_row; - int gcd = simple_gcd(128, ts); // 128 is to cover cache line sizes for common architectures without getting involved - // with trying to get it from ggml - int64_t num_blocks_gcd = (num_blocks + gcd - 1)/gcd; - int64_t block_per_thread = ((num_blocks_gcd + nth - 1)/nth)*gcd; - int64_t first_block = ith*block_per_thread; - int64_t last_block = MIN(num_blocks, first_block + block_per_thread); - while (first_block < last_block) { - int64_t i13 = first_block/(ne11*ne12*blocks_per_row); - int64_t i12 = (first_block - i13*ne11*ne12*blocks_per_row)/(ne11*blocks_per_row); - int64_t i11 = (first_block - (i13*ne12 + i12)*ne11*blocks_per_row)/blocks_per_row; - int64_t i10 = first_block % blocks_per_row; - int64_t blocks_to_do = MIN(blocks_per_row - i10, last_block - first_block); - from_float((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11) + i10*bs, - (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + i10*ts), blocks_to_do*bs); - first_block += blocks_to_do; - } -#else +//#ifdef GGML_USE_IQK_MULMAT +// int ts = type_traits[vec_dot_type].type_size; +// int bs = type_traits[vec_dot_type].blck_size; +// int64_t blocks_per_row = ne10/bs; +// int64_t num_blocks = ne11*ne12*ne13*blocks_per_row; +// int gcd = simple_gcd(128, ts); // 128 is to cover cache line sizes for common architectures without getting involved +// // with trying to get it from ggml +// int64_t num_blocks_gcd = (num_blocks + gcd - 1)/gcd; +// int64_t block_per_thread = ((num_blocks_gcd + nth - 1)/nth)*gcd; +// int64_t first_block = ith*block_per_thread; +// int64_t last_block = MIN(num_blocks, first_block + block_per_thread); +// while (first_block < last_block) { +// int64_t i13 = first_block/(ne11*ne12*blocks_per_row); +// int64_t i12 = (first_block - i13*ne11*ne12*blocks_per_row)/(ne11*blocks_per_row); +// int64_t i11 = (first_block - (i13*ne12 + i12)*ne11*blocks_per_row)/blocks_per_row; +// int64_t i10 = first_block % blocks_per_row; +// int64_t blocks_to_do = MIN(blocks_per_row - i10, last_block - first_block); +// from_float((float *)((char *)src1->data + i13*nb13 + i12*nb12 + i11*nb11) + i10*bs, +// (void *)(wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + i10*ts), blocks_to_do*bs); +// first_block += blocks_to_do; +// } +//#else for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -14158,7 +14158,7 @@ UseGgmlGemm1:; } } } -#endif +//#endif ggml_barrier(params->shared); From 047ba895bb3d94f055756c1ec7767b3342cb9c90 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 19 Feb 2025 10:01:49 +0200 Subject: [PATCH 010/132] Repack also experts (#210) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_quantize.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index f33fc183..24b49d89 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -6507,7 +6507,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { if (!tensor) return; if (!ggml_is_contiguous(tensor)) return; if (strncmp(tensor->name, "token_embd.weight", GGML_MAX_NAME) == 0) return; - if (tensor->ne[1] % 4 || tensor->ne[2]*tensor->ne[3] > 1) return; + if (tensor->ne[1] % 4) return; static const std::unordered_map k_map = { { GGML_TYPE_IQ2_K, { GGML_TYPE_IQ2_K_R4, 4, (Repack::repack_func)repack_iq2_k} }, { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, @@ -6544,8 +6544,10 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { auto& r = it->second; + auto nrows = ggml_nrows(tensor); + int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); - int num_chunks = (tensor->ne[1] + kChunk*r.num_rows - 1)/(kChunk*r.num_rows); + int num_chunks = (nrows + kChunk*r.num_rows - 1)/(kChunk*r.num_rows); int nthread = std::min(num_chunks, max_thread); //printf("%s(%s): %s -> %s. %d rows, %d chunks, %d threads\n", __func__, tensor->name, ggml_type_name(tensor->type), ggml_type_name(r.new_type), @@ -6553,7 +6555,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { std::atomic counter(0);; auto compute = [&counter, &r, tensor, num_chunks, chunkSize = kChunk] () { - int nrows = tensor->ne[1]; + int nrows = ggml_nrows(tensor); int n_per_row = tensor->ne[0]; auto row_size = ggml_row_size(tensor->type, n_per_row); std::vector qtmp(r.num_rows*row_size); From a0ebfdd661a2ccb2700b0e36cfc10ca1a2b4de98 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 19 Feb 2025 11:47:07 +0200 Subject: [PATCH 011/132] Q8_KV: 8-bit quantization type targeting the KV cache (#208) * Adding q8_KV - Basics + AVX2 gemm/gemv * q8_KV: Better AVX2 gemm * q8_KV: Better Zen4 gemm We get 225.7 t/s for L3-8B. In comparison q8_0 without run-tinme-repacking is at 169 t/s. * q8_KV: AVX2 gemm/gemv We get 254 t/s for L3-8B vs 194 t/s for q8_0 without rtr. * q8_KV: be able to use it for K cache This required quite a few fixes in ggml and llama.cpp: * ggml: do not calculate row size as n/block_size*type_size. I had removed most of it when implementing the quants with per row scale, bit it was stull lurking in ggml_copy. Not sure if these were the last remnants of ggmil-style row sizes, or if there are still places left * llama.cpp: get rid of the the 1d K cache assumption. Create and manage the K-cache as a 2D tensor so we can have per row meta data as needed by q8_KV. Using q8_KV for K-cache results in non-negligible performance gains. More details to follow, but for DeepSeek-Lite with MLA, we get 18% speedup for PP-8192 compared to q8_0 K-cache. * q8_KV: be able to use it for K cache in FA * q8_KV: repack it for K*Q in FA * q8_KV: slightly faster gemv on Zen4 * q8_KV: slightly faster gemv on Zen4 * q8_KV: ARM_NEON We get PP-512 = 167 t/s for L3-8B without interleaving! We do the interleaving on the fly, so I wonder if this could be done for other quants as well. * q8_KV: use it in FA on NEON * q8_KV_r8 - repacked q8_KV On Zen4 it is slower than q8_k_r8 (292 vs 370 t/s) This makes no sense whatsoever as the q8_KV_r8 GEMM is basically the q8_k_r8 GEMM with the unnecessary block stuff removed (so, one would think that it would be faster). * q8_KV_r8: don't use nrc_y = 16 on Zen4 This is faster - 350 t/s. Why? Much better than the 290 t/s we had before, but still slower than the 370 t/s for q8_k_r8. * q8_KV: nrc_y = 16 also doesn't pay off in FA * Minor --------- Co-authored-by: Iwan Kawrakow --- common/common.cpp | 3 + examples/llama-bench/llama-bench.cpp | 3 + examples/quantize/quantize.cpp | 2 + ggml/include/ggml.h | 4 + ggml/src/ggml-quants.c | 6 +- ggml/src/ggml.c | 48 ++- ggml/src/iqk/iqk_mul_mat.cpp | 598 ++++++++++++++++++++++++++- ggml/src/iqk/iqk_quantize.cpp | 290 ++++++++++++- ggml/src/iqk/iqk_quantize.h | 12 + include/llama.h | 2 + src/llama.cpp | 49 ++- 11 files changed, 983 insertions(+), 34 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 44678d7a..f7a6f76f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2259,6 +2259,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { if (s == "q6_0") { return GGML_TYPE_Q6_0; } + if (s == "q8_KV") { + return GGML_TYPE_Q8_KV; + } throw std::runtime_error("Invalid cache type: " + s); } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 95df06dc..0222c213 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -339,6 +339,9 @@ static ggml_type ggml_type_from_name(const std::string & s) { if (s == "q6_0") { return GGML_TYPE_Q6_0; } + if (s == "q8_KV") { + return GGML_TYPE_Q8_KV; + } return GGML_TYPE_COUNT; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7ceee208..916f57ec 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -56,6 +56,7 @@ static const std::vector QUANT_OPTIONS = { { "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", }, { "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", }, { "Q8_0_R8", LLAMA_FTYPE_MOSTLY_Q8_0_R8, " 8.50 bpw quantization", }, + { "Q8_KV", LLAMA_FTYPE_MOSTLY_Q8_KV, " 8.00 bpw quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", }, @@ -82,6 +83,7 @@ static const std::vector QUANT_OPTIONS = { { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, { "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", }, { "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", }, + { "Q8_KV_R8", LLAMA_FTYPE_MOSTLY_Q8_KV_R8, "Q8_KV repacked", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 66bcb25a..d2131a15 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -416,6 +416,7 @@ extern "C" { GGML_TYPE_Q8_K32 = 148, GGML_TYPE_Q8_KR8 = 149, GGML_TYPE_Q8_K128 = 150, + GGML_TYPE_Q8_KV = 151, GGML_TYPE_Q4_0_R8 = 202, GGML_TYPE_Q5_0_R4 = 206, @@ -442,6 +443,7 @@ extern "C" { GGML_TYPE_IQ4_K_R4 = 339, GGML_TYPE_IQ5_K_R4 = 340, GGML_TYPE_IQ4_KS_R4 = 344, + GGML_TYPE_Q8_KV_R8 = 398, GGML_TYPE_Q8_K_R8 = 399, GGML_TYPE_COUNT, }; @@ -501,6 +503,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_KV = 140, // except 1d tensors // GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors @@ -527,6 +530,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors }; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index fe7de167..e8218e76 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15214,8 +15214,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ3_K_R4: break; case GGML_TYPE_IQ4_K_R4: break; case GGML_TYPE_IQ5_K_R4: break; - case GGML_TYPE_IQ4_KS_R4: break; - case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_IQ4_KS_R4:break; + case GGML_TYPE_Q8_KV_R8: break; + case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_Q8_KV: break; case GGML_TYPE_BF16_R16: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8ab6b0a9..0aee8dd4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1362,6 +1362,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K128, .row_meta_size = 0, }, + [GGML_TYPE_Q8_KV] = { + .type_name = "q8_KV", + .blck_size = 32, + .type_size = 32, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_KV, + .from_float = quantize_row_q8_KV, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_ref, + .vec_dot = vec_dot_q8_KV_q8_KV, + .vec_dot_type = GGML_TYPE_Q8_KV, + .row_meta_size = 8, + }, + [GGML_TYPE_Q8_KV_R8] = { + .type_name = "q8_KV_r8", + .blck_size = 32, + .type_size = 32, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_KV_r8, + .from_float = quantize_row_q8_KV_r8, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_KV_r8_ref, + .vec_dot = vec_dot_q8_KV_r8_q8_KV, + .vec_dot_type = GGML_TYPE_Q8_KV, + .row_meta_size = 4, + }, [GGML_TYPE_Q8_K16] = { .type_name = "q8_K16", .blck_size = 64, @@ -4373,6 +4397,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q8_KV: wtype = GGML_TYPE_Q8_KV; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; @@ -4384,6 +4409,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break; case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break; + case GGML_FTYPE_MOSTLY_Q8_KV_R8: wtype = GGML_TYPE_Q8_KV_R8; break; case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break; case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; @@ -9436,7 +9462,7 @@ static void ggml_compute_forward_dup_f16( float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -9722,7 +9748,7 @@ static void ggml_compute_forward_dup_bf16( float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -10042,7 +10068,7 @@ static void ggml_compute_forward_dup_f32( ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + size_t rs = ggml_row_size(dst->type, ne00); //nb0 * (ne00 / ggml_blck_size(dst->type)); char * dst_ptr = (char *) dst->data; for (int i03 = 0; i03 < ne03; i03++) { @@ -10936,6 +10962,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -11406,6 +11433,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -11573,6 +11601,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -14061,7 +14090,7 @@ static void ggml_compute_forward_mul_mat( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if GGML_USE_IQK_MULMAT || GGML_USE_LLAMAFILE +#if GGML_USE_LLAMAFILE // broadcast factors const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; @@ -14344,7 +14373,7 @@ static void ggml_compute_forward_mul_mat_id( char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t)); struct mmid_row_mapping { int32_t i1; @@ -14768,6 +14797,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -14779,6 +14809,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15186,6 +15217,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -15473,6 +15505,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: @@ -15487,6 +15520,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: @@ -16116,6 +16150,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K_R4: case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_KR8: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS_R4: @@ -16159,6 +16194,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: case GGML_TYPE_Q8_K128: + case GGML_TYPE_Q8_KV: case GGML_TYPE_Q8_K16: case GGML_TYPE_Q8_K32: case GGML_TYPE_Q4_0_4_4: @@ -22970,6 +23006,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_KV: result = quantize_q8_KV(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -22981,6 +23018,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_KV_R8:result = quantize_q8_KV_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8d6b45da..3bfded73 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -269,6 +269,8 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q8_KV: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: @@ -301,6 +303,8 @@ struct MulMat { case GGML_TYPE_IQ4_XS_R8: case GGML_TYPE_Q4_0_R8: case GGML_TYPE_Q8_0_R8: + case GGML_TYPE_Q8_KV: + case GGML_TYPE_Q8_KV_R8: case GGML_TYPE_Q8_K_R8: return 8; case GGML_TYPE_BF16_R16: return 16; default: return 1; @@ -6107,7 +6111,7 @@ static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn // The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) template static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); @@ -6169,6 +6173,230 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn } } +// The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) +template +static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nb = n / 16; + __m256i acc[nrc_y] = {}; + __m256i qx[4]; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + float sy[nrc_y]; +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ix += 8) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto dx = _mm256_loadu_ps(dptr); + auto q8x = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows + qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0); + qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1); + qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2); + qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3); +#ifndef HAVE_FANCY_SIMD + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); + auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy])); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy])); +#endif + info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy]))); + acc[iy] = _mm256_setzero_si256(); + } + } +} + +template +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m256i qx[2]; + __m256i acc[2*nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#else + __m256i sx[2]; + auto m1 = _mm256_set1_epi16(1); +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr+1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127)); +#else + qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + for (int j = 0; j < 2; ++j) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j)); +#else + auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j])); + acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot)); +#endif + } + } + } + if (int i = 2*(n/64); i < n/32) { +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i); + sx[0] = _mm256_sign_epi8(qx[0], qx[0]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i)); +#else + auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0])); + acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot)); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1])); +#ifdef HAVE_FANCY_SIMD + info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy])); +#else + info.store(ix, iy, dx[0]*dy[iy]*sumi); +#endif + acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256(); + } + } +} + +template +static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m256i qx[4]; +#ifndef HAVE_FANCY_SIMD + __m256i sx[4]; + auto m1 = _mm256_set1_epi16(1); +#endif + __m256i acc[nrc_y] = {}; + float dy[nrc_y]; +#ifdef HAVE_FANCY_SIMD + int32_t sy[nrc_y]; +#endif + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; +#ifdef HAVE_FANCY_SIMD + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -127*iptr[0]; +#endif + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[4]; + float dx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int kx = 0; kx < 4; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i); + auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]); +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127)); +#else + qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]); + qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]); + qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]); + qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); +#ifdef HAVE_FANCY_SIMD + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2)); + auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4)); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34)); +#endif + } + } + auto scales_x = _mm_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1)); +#ifdef HAVE_FANCY_SIMD + sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy])); +#endif + auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy])); + info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi))); + acc[iy] = _mm256_setzero_si256(); + } + } +} + #ifdef __AVX512BF16__ template static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -9114,6 +9342,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { #endif expected_typeB = GGML_TYPE_Q8_KR8; break; + case GGML_TYPE_Q8_KV: + assert (ne00 % 32 == 0); + mm.funcs[0] = mul_mat_q8_KV_q8_KV_1<1>; + mm.funcs[1] = mul_mat_q8_KV_q8_KV<2>; + mm.funcs[2] = mul_mat_q8_KV_q8_KV<3>; + mm.funcs[3] = mul_mat_q8_KV_q8_KV<4>; + mm.funcs[4] = mul_mat_q8_KV_q8_KV<5>; + mm.funcs[5] = mul_mat_q8_KV_q8_KV<6>; + mm.funcs[6] = mul_mat_q8_KV_q8_KV<7>; + mm.funcs[7] = mul_mat_q8_KV_q8_KV<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_q8_KV_q8_KV<16>; +#endif + expected_typeB = GGML_TYPE_Q8_KV; + break; + case GGML_TYPE_Q8_KV_R8: + assert (ne00 % 32 == 0); + mm.funcs[0] = mul_mat_q8_KV_r8_q8_KV<1>; + mm.funcs[1] = mul_mat_q8_KV_r8_q8_KV<2>; + mm.funcs[2] = mul_mat_q8_KV_r8_q8_KV<3>; + mm.funcs[3] = mul_mat_q8_KV_r8_q8_KV<4>; + mm.funcs[4] = mul_mat_q8_KV_r8_q8_KV<5>; + mm.funcs[5] = mul_mat_q8_KV_r8_q8_KV<6>; + mm.funcs[6] = mul_mat_q8_KV_r8_q8_KV<7>; + mm.funcs[7] = mul_mat_q8_KV_r8_q8_KV<8>; + expected_typeB = GGML_TYPE_Q8_KV; + break; case GGML_TYPE_IQ4_K_R4: assert (ne00 % QK_K == 0); mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>; @@ -13424,6 +13679,123 @@ void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf } } +static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%32 == 0); + int32x4_t acc[4] = {}; + auto dptr = (const float *)info.src1_row(0); + const float dy = dptr[0]; + auto q8y = (const int8_t *)(dptr + 2); + for (int ix = 0; ix < nrc_x; ++ix) { + auto dx = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dx + 2); + for (int i = 0; i < n/64; ++i) { + auto qx = vld1q_s8_x4(q8x + 64*i); + for (int j = 0; j < 4; ++j) { + acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j)); + } + } + if (int i = 2*(n/64); i < n/32) { + auto qx = vld1q_s8_x2(q8x + 32*i); + for (int j = 0; j < 2; ++j) { + acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j)); + } + } + acc[0] = vaddq_s32(acc[0], acc[1]); + acc[2] = vaddq_s32(acc[2], acc[3]); + acc[0] = vaddq_s32(acc[0], acc[2]); + info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0])); + acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0); + } +} + +template +static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(n%16 == 0); + int8x16_t qx[4]; + int32x4_t acc[nrc_y] = {}; + float dy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[4]; + float dx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int kx = 0; kx < 4; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/16; ++i) { + for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i); + auto row01 = vtrnq_s32(qx[0], qx[1]); + auto row23 = vtrnq_s32(qx[2], qx[3]); + qx[0] = vtrn1q_s64(row01.val[0], row23.val[0]); + qx[1] = vtrn1q_s64(row01.val[1], row23.val[1]); + qx[2] = vtrn2q_s64(row01.val[0], row23.val[0]); + qx[3] = vtrn2q_s64(row01.val[1], row23.val[1]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8y[iy] + 16*i); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2); + acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3); + } + } + auto scales_x = vld1q_f32(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy])); + info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy]))); + acc[iy] = vdupq_n_s32(0); + } + } +} + +template +void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + int32x4_t acc[2*nrc_y] = {}; + float dy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + for (int ix = 0; ix < nrc_x; ix += 8) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto q8x = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < n/16; ++ib) { + auto q1 = vld1q_s8_x4(q8x + 128*ib + 0); + auto q2 = vld1q_s8_x4(q8x + 128*ib + 64); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8y[iy]+16*ib); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2); + acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3); + acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3); + } + } + auto scale1_x = vld1q_f32(dptr+0); + auto scale2_x = vld1q_f32(dptr+4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto scale_y = vdupq_n_f32(dy[iy]); + auto scale1 = vmulq_f32(scale1_x, scale_y); + auto scale2 = vmulq_f32(scale2_x, scale_y); + info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0]))); + info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1]))); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); + } + } +} + void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8<1, block_q8_0_x4> q8(info); @@ -14000,6 +14372,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k); expected_Btype = GGML_TYPE_Q8_KR8; break; + case GGML_TYPE_Q8_KV: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_q8_KV); + m.funcs[0] = mul_mat_q8_KV_q8_KV_1; + m.func16 = mul_mat_q8_KV_q8_KV<16>; + expected_Btype = GGML_TYPE_Q8_KV; + break; + case GGML_TYPE_Q8_KV_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_KV_r8_q8_KV); + expected_Btype = GGML_TYPE_Q8_KV; + break; case GGML_TYPE_IQ2_K_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; @@ -14347,13 +14729,49 @@ struct HelperF16 final : public BaseHelper { } }; +template struct block_q8_KV { + float d; + int s; + int8_t qs[D]; +}; + +template +struct HelperQ8KV final : public BaseHelper { + using Base = BaseHelper; + using block_q8 = block_q8_KV; + constexpr static int block_size_q = D; + HelperQ8KV(const char * data, int stride) : Base(data, stride) {} + + // Needed for v * softmax(k * q) + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + auto q8 = (const block_q8_KV *)Base::lblock(l1); +#ifdef __aarch64__ + auto vd = F16::set1(q8->d); + auto qs = vld1_s8_x2(q8->qs + 8*i); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); +#else + auto vd = F16::set1(q8->d); +#ifdef HAVE_FANCY_SIMD + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0)))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1)))); +#else + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+0))))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(q8->qs+8*i+8))))); +#endif +#endif + } +}; + template struct HelperQ80 final : public BaseHelper { using Base = BaseHelper; #ifdef HAVE_FANCY_SIMD using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; #else using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; #endif HelperQ80(const char * data, int stride) : Base(data, stride) {} @@ -14397,23 +14815,33 @@ struct HelperQ80 final : public BaseHelper { y += D/QK8_1; } } + + static inline void convert(int nq, int stride_q, const float * q, block_q8_KV * y) { + for (int i = 0; i < nq; ++i) { + quantize_row_q8_KV(q, y, D); + q += stride_q; + ++y; + } + } }; template -struct HelperQ80R4 : public BaseHelper { +struct HelperQ80R8 : public BaseHelper { using Base = BaseHelper; #ifdef __AVX2__ + constexpr static int block_size_q = QK8_1; using block_q8 = block_q8_1; #else + constexpr static int block_size_q = QK8_0; using block_q8 = block_q8_0; #endif - HelperQ80R4(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { + HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector repack(int nk, const HelperQ80 q8) { + static std::vector repack(int nk, const HelperQ80& q8) { static_assert(D%QK8_0 == 0); GGML_ASSERT(nk%8 == 0); constexpr int nblock = D/QK8_0; @@ -14512,10 +14940,107 @@ struct HelperQ80R4 : public BaseHelper { std::vector r4; }; +// TODO: unite this with the above +template +struct HelperQ8KVR8 : public BaseHelper { + using Base = BaseHelper; + constexpr static int block_size_q = D; + using block_q8 = block_q8_KV; + + struct block_q8_KV_r8 { + float d[8]; + int8_t qs[8*D]; + }; + + HelperQ8KVR8(int nk, const HelperQ8KV& q8) : Base(q8.data, q8.stride) { + r4 = repack(nk, q8); + Base::data = (const char *)r4.data(); + Base::stride = sizeof(block_q8_KV_r8)/8; + } + + static std::vector repack(int nk, const HelperQ8KV& q8) { + static_assert(D%32 == 0); + GGML_ASSERT(nk%8 == 0); + std::vector result(nk/8); + auto y = result.data(); +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + const int8_t * x8[8]; + for (int ix = 0; ix < nk/8; ++ix) { + for (int k = 0; k < 8; ++k) { + auto dptr = (const float *)(q8.data + (8*ix + k)*q8.stride); + y[ix].d[k] = dptr[0]; + x8[k] = (const int8_t *)(dptr + 2); + } + for (int ib = 0; ib < D/16; ++ib) { +#ifdef __AVX2__ + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +#endif + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2); + _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+3, m3); +#elif defined __ARM_NEON + // TODO + m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib); + m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib); + m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib); + m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ix].qs + 0 + 128*ib, m0); + vst1q_s8_x2(y[ix].qs + 32 + 128*ib, m1); + vst1q_s8_x2(y[ix].qs + 64 + 128*ib, m2); + vst1q_s8_x2(y[ix].qs + 96 + 128*ib, m3); +#else + // TODO + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + } + } + return result; + } + + std::vector r4; +}; + template struct HelperQ40 final : public BaseHelper { using Base = BaseHelper; using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; HelperQ40(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -14559,6 +15084,7 @@ template struct HelperQ41 final : public BaseHelper { using Base = BaseHelper; using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; HelperQ41(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -14649,8 +15175,10 @@ template struct HelperQ60 final : public BaseHelper { #ifdef __aarch64__ using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; #else using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; #endif using Base = BaseHelper; HelperQ60(const char * data, int stride) : Base(data, stride) {} @@ -15071,9 +15599,9 @@ struct FlashQKV { } inline void normalize_and_store(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { - GGML_ASSERT(fms.S[j] > 0); - auto norm = F16::set1(1/fms.S[j]); - //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); + //GGML_ASSERT(fms.S[j] > 0); + //auto norm = F16::set1(1/fms.S[j]); + auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); @@ -15357,13 +15885,29 @@ struct FlashQKfp32 { #endif #endif } - else if constexpr (std::is_same_v>) { + else if constexpr (std::is_same_v>) { +#ifdef __aarch64__ + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#else +#ifdef HAVE_FANCY_SIMD + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); +#endif + if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); +#endif + } + else if constexpr (std::is_same_v>) { #ifdef __aarch64__ MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_0, nq); #else MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r8_q8_1, nq); #endif } + else if constexpr (std::is_same_v>) { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq); + } else if constexpr (std::is_same_v>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0(q_step); - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; for (int iq = 0; iq < q_step/nrc_q; ++iq) { mul_mat(D, kh.block, kh.stride, info, k_step); info.cur_y += nrc_q; @@ -15428,7 +15972,7 @@ struct FlashQKfp32 { static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS& fms) { auto [mul_mat, nrc_q] = mul_mat_kernel(nq); - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + DataInfo info{fms.cache, (const char *)q, k_step, (D/KHelper::block_size_q)*sizeof(block_q8), 0, 1, nullptr}; for (int iq = 0; iq < nq/nrc_q; ++iq) { mul_mat(D, kh.block, kh.stride, info, k_step); info.cur_y += nrc_q; @@ -15516,7 +16060,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, FlashMS& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv) { - typename KHelper::block_q8 q8[q_step*(Dk/QK8_0)]; + typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; #if FA_TIMING Perf perf(false); #endif @@ -15613,12 +16157,28 @@ struct FlashAttn { if (nq1 >= 8) { #if FA_TIMING auto t1 = Perf::cur_time(); - HelperQ80R4 khr4(nk1, kh); + HelperQ80R8 khr4(nk1, kh); Perf::instance().accum(4, t1); #else - HelperQ80R4 khr4(nk1, kh); + HelperQ80R8 khr4(nk1, kh); #endif - compute_helper_q, VHelper, FlashQKfp32>( + compute_helper_q, VHelper, FlashQKfp32>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } else{ + compute_helper_q>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } + } + else if constexpr (std::is_same_v>) { + if (nq1 >= 8) { +#if FA_TIMING + auto t1 = Perf::cur_time(); + HelperQ8KVR8 khr4(nk1, kh); + Perf::instance().accum(4, t1); +#else + HelperQ8KVR8 khr4(nk1, kh); +#endif + compute_helper_q, VHelper, FlashQKfp32>( khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else{ compute_helper_q>( @@ -16142,6 +16702,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ80 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q8_KV: { + HelperQ8KV vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; case GGML_TYPE_Q6_0: { HelperQ60 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); @@ -16179,6 +16743,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_Q8_KV: { + HelperQ8KV kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; case GGML_TYPE_Q6_0: { HelperQ60 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); @@ -16210,7 +16778,7 @@ inline bool flash_attn_is_supported(ggml_type type) { if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; #else - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true; #endif return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 24b49d89..7b777a1f 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2967,6 +2967,103 @@ void iqk_quantize_row_q8_K128(const float * x, void * vy, int64_t k) { } #endif } +// TODO: merge this with the above template +void iqk_quantize_row_q8_KV(const float * x, void * vy, int64_t k) { + assert(k % 32 == 0); + auto dptr = (float *)vy; + auto q8 = (int8_t *)(dptr + 2); +#ifdef __AVX2__ + const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + __m256 maxAbs = _mm256_setzero_ps(); + for (int ib = 0; ib < k/8; ++ib) { + const __m256 v = _mm256_loadu_ps(x + 8*ib); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v)); + } + const float maxScalar = hmax_f32_8(maxAbs); + if (!maxScalar) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = maxScalar / 127.f; + auto mul = _mm256_set1_ps(1/dptr[0]); + auto isum = _mm256_setzero_si256(); + for (int i = 0; i < k/32; i++) { + __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 0)); + __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 8)); + __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 16)); + __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(x + 32*i + 24)); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + v2 = _mm256_round_ps(v2, _MM_ROUND_NEAREST); + v3 = _mm256_round_ps(v3, _MM_ROUND_NEAREST); + __m256i i0 = _mm256_cvtps_epi32(v0); + __m256i i1 = _mm256_cvtps_epi32(v1); + __m256i i2 = _mm256_cvtps_epi32(v2); + __m256i i3 = _mm256_cvtps_epi32(v3); + isum = _mm256_add_epi32(isum, _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + i0 = _mm256_packs_epi16( i0, i2 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + _mm256_storeu_si256((__m256i *)q8, i0); + q8 += 32; + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = hsum_i32_8(isum); +#elif defined __ARM_NEON + int32x4_t ival[8]; + auto vmax = vdupq_n_f32(0.f); + for (int j = 0; j < k; j += 4) { + vmax = vmaxq_f32(vmax, vabsq_f32(vld1q_f32(x + j))); + } + auto smax = vmaxvq_f32(vmax); + if (!smax) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = smax/127; + auto vid = vdupq_n_f32(1/dptr[0]); + auto isum = vdupq_n_s32(0); + for (int ib = 0; ib < k/32; ++ib) { + auto xb = x + 32*ib; + for (int k = 0; k < 8; ++k) { + auto val = vld1q_f32(xb + 4*k); + ival[k] = vcvtnq_s32_f32(vmulq_f32(val, vid)); + isum = vaddq_s32(isum, ival[k]); + } + for (int k = 0; k < 4; ++k) { + auto i16 = vcombine_s16(vmovn_s32(ival[2*k+0]), vmovn_s32(ival[2*k+1])); + vst1_s8(q8, vmovn_s16(i16)); + q8 += 8; + } + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = vaddvq_s32(isum); +#else + float amax = 0; + for (int j = 0; j < k; ++j) { + float ax = std::abs(x[j]); + amax = std::max(amax, ax); + } + if (!amax) { + dptr[0] = dptr[1] = 0; + std::memset(q8, 0, k*sizeof(int8_t)); + return; + } + dptr[0] = amax/127; + float id = 1/dptr[0]; + int isum = 0; + for (int i = 0; i < k; i++) { + q8[i] = nearest_int(id*x[i]); + isum += q8[i]; + } + auto iptr = (int32_t *)(dptr + 1); + iptr[0] = isum; +#endif +} } void quantize_row_q8_K128(const float * x, void * vy, int64_t k) { @@ -3886,7 +3983,7 @@ static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8 #ifdef HAVE_FANCY_SIMD static void modify_q8_0_r8(int64_t k, char * cy) { - auto y = (block_iq4_nl_r8 *)cy; + auto y = (block_q8_0_r8 *)cy; int nb = k/(32*8); for (int ib = 0; ib < nb; ++ib) { for (int l = 0; l < 4; ++l) { @@ -5412,6 +5509,150 @@ void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t b GGML_UNUSED(by); } +// +// ========================================= q8_KV_r8 +// + +void quantize_row_q8_KV_r8_ref(const float * x, void * y, int64_t k) { + quantize_q8_KV_r8(x, y, 8, k/8, nullptr); +} + +void quantize_row_q8_KV_r8(const float * x, void * y, int64_t k) { + quantize_q8_KV_r8(x, y, 8, k/8, nullptr); +} + +static void repack_q8_KV(int nrows, int n_per_row, const char * cx, char * cy, [[maybe_unused]] bool online) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%16 == 0); + auto row_size_x = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto row_size_y = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row); + const int8_t * x8[8]; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = 0; row < nrows; row += 8) { + auto dy = (float *)cy; + auto qy = (int8_t *)(dy + 8); + for (int k = 0; k < 8; ++k) { + auto dx = (const float *)(cx + k*row_size_x); + dy[k] = dx[0]; + x8[k] = (const int8_t *)(dx + 2); + } + for (int ib = 0; ib < n_per_row/16; ++ib) { +#ifdef __AVX2__ +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4]+ib), _mm_loadu_si128((const __m128i *)x8[0]+ib)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5]+ib), _mm_loadu_si128((const __m128i *)x8[1]+ib)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6]+ib), _mm_loadu_si128((const __m128i *)x8[2]+ib)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7]+ib), _mm_loadu_si128((const __m128i *)x8[3]+ib)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); +#ifdef HAVE_FANCY_SIMD + if (online) { + m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + } +#endif + _mm256_storeu_si256((__m256i *)qy + 4*ib+0, m0); + _mm256_storeu_si256((__m256i *)qy + 4*ib+1, m1); + _mm256_storeu_si256((__m256i *)qy + 4*ib+2, m2); + _mm256_storeu_si256((__m256i *)qy + 4*ib+3, m3); +#elif defined __ARM_NEON + m0.val[0] = vld1q_s8(x8[0]+16*ib); m0.val[1] = vld1q_s8(x8[4]+16*ib); + m1.val[0] = vld1q_s8(x8[1]+16*ib); m1.val[1] = vld1q_s8(x8[5]+16*ib); + m2.val[0] = vld1q_s8(x8[2]+16*ib); m2.val[1] = vld1q_s8(x8[6]+16*ib); + m3.val[0] = vld1q_s8(x8[3]+16*ib); m3.val[1] = vld1q_s8(x8[7]+16*ib); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(qy + 0 + 128*ib, m0); + vst1q_s8_x2(qy + 32 + 128*ib, m1); + vst1q_s8_x2(qy + 64 + 128*ib, m2); + vst1q_s8_x2(qy + 96 + 128*ib, m3); +#else + // TODO + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + + } + cx += 8*row_size_x; + cy += online ? 8*row_size_x : 8*row_size_y; + //So, if we are run-time-repacking (online = true) we don't want to change the stride, so we just leave some unused space at the end of each row + } +} +#ifdef HAVE_FANCY_SIMD +static void modify_q8_KV_r8(int64_t k, char * cy) { + int8_t * q8 = (int8_t *)(cy + 8*sizeof(float)); + for (int j = 0; j < k; ++j) q8[j] += 127; +} +#endif + +size_t quantize_q8_KV_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%16 == 0); + char * qcur = (char *)dst; + auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_KV_R8, n_per_row); + std::vector qtmp(8*row_size_0); + for (int row = 0; row < nrows; row += 8) { + quantize_q8_KV(src, (void *)qtmp.data(), 8, n_per_row, imatrix); + repack_q8_KV(8, n_per_row, qtmp.data(), qcur, false); + qcur += 8*row_size_1; + src += 8*n_per_row; + } + return nrows*row_size_1; +} + +void dequantize_row_q8_KV_r8(const void * vx, float * y, int64_t k) { + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; + auto dptr = (const float *)vx; + auto q8 = (const int8_t *)(dptr + 8); + for (int ib = 0; ib < n_per_row/16; ++ib) { + for (int k = 0; k < 8; ++k) { + for (int l = 0; l < 4; ++l) { + for (int i = 0; i < 4; ++i) y8[k][16*ib + 4*l + i] = dptr[k] * q8[128*ib + 32*l + 4*k + i]; + } + } + } +} + +void vec_dot_q8_KV_r8_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV_R8, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + // // ========================================= bf16_r4 // @@ -6450,6 +6691,47 @@ void vec_dot_iq1_m_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t GGML_UNUSED(by); } +void quantize_row_q8_KV(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_KV(x, vy, k); +} + +void quantize_row_q8_KV_ref(const float * x, void * y, int64_t k) { + quantize_row_q8_KV(x, y, k); +} + +size_t quantize_q8_KV(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + (void)imatrix; + auto row_size = ggml_row_size(GGML_TYPE_Q8_KV, n_per_row); + auto q = (char *)dst; + for (int row = 0; row < nrows; ++row) { + quantize_row_q8_KV(src, q, n_per_row); + src += n_per_row; + q += row_size; + } + return row_size*nrows; +} + +void dequantize_row_q8_KV(const void * x, float * y, int64_t k) { + auto dptr = (const float *)x; + float d = dptr[0]; + auto q8 = (const int8_t *)(dptr + 2); + for (int j = 0; j < k; ++j) y[j] = d * q8[j]; +} + +void vec_dot_q8_KV_q8_KV(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_KV, vx, 0, GGML_TYPE_Q8_KV, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + + //================================================ namespace { @@ -6472,8 +6754,9 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} }, #endif #ifdef HAVE_FANCY_SIMD - { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} }, - { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, + { GGML_TYPE_Q8_0_R8, {modify_q8_0_r8, 8} }, + { GGML_TYPE_Q8_K_R8, {modify_q8_k_r8, 8} }, + { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} }, #endif }; auto it = k_mod_map.find(tensor->type); @@ -6532,6 +6815,7 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R8, 8, (Repack::repack_func)repack_q8_0} }, { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, + { GGML_TYPE_Q8_KV, { GGML_TYPE_Q8_KV_R8, 8, (Repack::repack_func)repack_q8_KV} }, #ifdef __AVX512BF16__ { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16}}, { GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16} }, diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 97719361..76fbac3b 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -217,6 +217,18 @@ size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT ds void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_q8_KV_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KV(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_KV(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_KV(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_KV_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q8_KV_r8_ref(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KV_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_KV_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_KV_r8(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_KV_r8_q8_KV(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); diff --git a/include/llama.h b/include/llama.h index 39251d35..b5ad65e7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -180,6 +180,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ3_KL = 146, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ2_KS = 147, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KSS = 148, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_KV = 149, // except 1d tensors // LLAMA_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors @@ -206,6 +207,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_K_R4 = 340, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_KV_R8 = 398, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file diff --git a/src/llama.cpp b/src/llama.cpp index 8c4a966d..0257a0a3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3180,6 +3180,10 @@ static bool llama_kv_cache_init( for (int i = 0; i < (int) n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_head = hparams.n_head(i); + const uint32_t n_head_kv = hparams.n_head_kv(i); + const uint32_t n_embd_head_k= hparams.n_embd_head_k; + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); ggml_tensor * k; @@ -3201,7 +3205,8 @@ static bool llama_kv_cache_init( const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); #if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); + ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); + //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #else ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #endif @@ -3215,7 +3220,10 @@ static bool llama_kv_cache_init( n_mla++; } else { - k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + //printf("Creating cache tensors:\n"); + //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k); + //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); @@ -4002,6 +4010,7 @@ struct llama_model_loader { case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; case GGML_TYPE_Q6_0: ftype = LLAMA_FTYPE_MOSTLY_Q6_0; break; case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q8_KV: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV; break; case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break; @@ -4012,6 +4021,7 @@ struct llama_model_loader { case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break; case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break; + case GGML_TYPE_Q8_KV_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_KV_R8; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; @@ -4730,6 +4740,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q8_KV: return "Q8_KV"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; @@ -4746,6 +4757,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4"; case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8"; + case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: return "Q8_KV_R8"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; @@ -8283,11 +8295,20 @@ static void llm_build_kv_store( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + GGML_ASSERT(kv.size == n_ctx); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, - (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); - cb(k_cache_view, "k_cache_view", il); + //struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, + // (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); + //cb(k_cache_view, "k_cache_view", il); + + auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k); + ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv, + k_row_size, k_row_size*n_head_kv*kv_head); // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); @@ -8706,7 +8727,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il], n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa), ggml_row_size(kv.k_l[il]->type, n_embd_head_k), 0); cb(k, "k", il); @@ -13507,8 +13528,9 @@ struct llm_build_context { ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); cb(kvr, "kvr", il); - ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*(kv_lora_rank + n_embd_head_qk_rope), - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope)*kv_head); + auto row_size = ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope); + ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_tokens, + row_size, row_size*kv_head); ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view)); ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank + n_embd_head_qk_rope, n_kv, @@ -16164,7 +16186,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_IQ5_K; } else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R8 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 && - new_type != GGML_TYPE_Q8_K_R8) { + new_type != GGML_TYPE_Q8_K_R8 && new_type != GGML_TYPE_Q8_KV && new_type != GGML_TYPE_Q8_KV_R8) { new_type = GGML_TYPE_Q6_K; } } @@ -16218,6 +16240,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (new_type == GGML_TYPE_Q8_K_R8) { new_type = GGML_TYPE_Q8_0; } + else if (new_type == GGML_TYPE_Q8_KV_R8) { + new_type = GGML_TYPE_Q8_0; + } else if (new_type == GGML_TYPE_IQ2_K_R4) { new_type = GGML_TYPE_IQ2_K; } @@ -16728,6 +16753,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q6_0: default_type = GGML_TYPE_Q6_0; break; case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_Q8_KV:default_type = GGML_TYPE_Q8_KV;break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break; @@ -16751,6 +16777,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break; case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break; + case LLAMA_FTYPE_MOSTLY_Q8_KV_R8: default_type = GGML_TYPE_Q8_KV_R8; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; @@ -17194,6 +17221,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; else chunk_size_multiplier = 8; } + else if (new_type == GGML_TYPE_Q8_KV_R8) { + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; + else chunk_size_multiplier = 8; + } else if (new_type == GGML_TYPE_IQ2_BN_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN; else chunk_size_multiplier = 4; From 498a582919f3955fee9ba4239d5f7a298a42425d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 20 Feb 2025 12:41:45 +0200 Subject: [PATCH 012/132] Optimized GEMM/GEMV for IQ1_S (#212) * Adding iq1_s to iqk_mul_mat (Zen4) * iq1_s: slightly better on Zen4 * iq1_s: AVX2 * iq1s: NEON --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 160 +++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 3bfded73..33e0a4a7 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3666,6 +3666,85 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +// sum[ qy_i * ls_k * (qx_i - 1+/-delta_k)] +// = sum[qy_i * qx_i * ls_k] - 1/8*sum[qy_i * ls_k * (8+/-o_k)] +// = 1/8 * ( sum[qy_i * qx_i * 8*ls+k] - sum[qy_i * ls_k * (8+/-o_k)] ) + +template +static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + Q8 q8(info); + __m256i qx[8]; + __m256i scales[4]; + __m256 acc[nrc_y] = {}; + auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000 + __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100); + for (int ix = 0; ix < nrc_x; ++ix) { + auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < n/QK_K; ++ibl) { + float d = GGML_FP16_TO_FP32(iq1s[ibl].d); + auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh); + auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7)); + scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1)); +#ifdef HAVE_FANCY_SIMD + auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask); + auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9)); +#else + auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask); + auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7))); +#endif + deltas128 = _mm_mullo_epi16(scales128, deltas128); + scales128 = _mm_slli_epi16(scales128, 3); + auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128); + auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128); + auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7 + auto all_scales = MM256_SET_M128I(scales128, scales128); + auto shuffle = shuffle0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle); + shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)); + } + const uint8_t * qs = iq1s[ibl].qs; + const uint16_t * qh = iq1s[ibl].qh; + for (int ib = 0; ib < QK_K/32; ib += 2) { + qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0); + auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1); +#ifdef HAVE_FANCY_SIMD + auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1); + auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2); + sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2)); +#else + auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1); + auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2); + auto dot = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2)); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot)); +#endif + } +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas); +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas)); +#endif + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, 0.125f*hsum_float_8(acc[iy])); + acc[iy] = _mm256_setzero_ps(); + } + } +} + template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -9473,6 +9552,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>; expected_typeB = GGML_TYPE_Q8_1_X4; break; + case GGML_TYPE_IQ1_S: + mm.funcs[0] = mul_mat_iq1_s_q8_K<1>; + mm.funcs[1] = mul_mat_iq1_s_q8_K<2>; + mm.funcs[2] = mul_mat_iq1_s_q8_K<3>; + mm.funcs[3] = mul_mat_iq1_s_q8_K<4>; + mm.funcs[4] = mul_mat_iq1_s_q8_K<5>; + mm.funcs[5] = mul_mat_iq1_s_q8_K<6>; + mm.funcs[6] = mul_mat_iq1_s_q8_K<7>; + mm.funcs[7] = mul_mat_iq1_s_q8_K<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq1_s_q8_K<16>; +#endif + expected_typeB = GGML_TYPE_Q8_K; + break; case GGML_TYPE_IQ1_S_R4: assert (ne00 % QK4_NL == 0); mm.funcs[0] = mul_mat_iq1_s_r4_q8_1<1>; @@ -12513,6 +12606,68 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI } } +template +static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + Q8 q8(info); + int8x16_t qx[16]; + int32x4_t scales[2]; + int16x4_t deltas[2]; + float32x4_t acc[nrc_y] = {}; + auto delta_mask = vdupq_n_u16(0x8000); + for (int ix = 0; ix < nrc_x; ++ix) { + auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < n/QK_K; ++ibl) { + float d = GGML_FP16_TO_FP32(iq1s[ibl].d); + auto qhb = vld1q_u16(iq1s[ibl].qh); + auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7)); + scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1)); + auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask); + // Note: we explicitely assume IQ1S_DELTA = 0.125 + auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask)); + //auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask)); + //deltas128 = vmulq_s16(scales128, deltas128); + scales128 = vshlq_n_u16(scales128, 3); + auto qs = iq1s[ibl].qs; + auto qh = iq1s[ibl].qh; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]}); + qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]}); + qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]}); + qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]}); + qs += 8; + } + scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128))); + scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128))); + deltas[0] = vget_low_s16 (deltas128); + deltas[1] = vget_high_s16(deltas128); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums8(iy, ibl); + auto sumi = vdupq_n_s32(0); + sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums)); + sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums)); + for (int k = 0; k < QK_K/128; ++k) { + auto qy = q8.load_quants_64(iy, ibl, 2*k+0); + auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]); + auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]); + auto dot12 = vpaddq_s32(dot1, dot2); + qy = q8.load_quants_64(iy, ibl, 2*k+1); + auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]); + auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]); + auto dot34 = vpaddq_s32(dot3, dot4); + auto dot = vpaddq_s32(dot12, dot34); + sumi = vmlaq_s32(sumi, dot, scales[k]); + } + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi)); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy])); + acc[iy] = vdupq_n_f32(0); + } + } +} + template static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -14327,6 +14482,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { m.func16 = mul_mat_iq2_s_r4_q8_k<16>; expected_Btype = GGML_TYPE_Q8_K; break; + case GGML_TYPE_IQ1_S: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_q8_K); + m.func16 = mul_mat_iq1_s_q8_K<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; case GGML_TYPE_IQ1_S_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1; From a45da7bfbf75503fe9e5a2f675db7825afdc6310 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 20 Feb 2025 13:55:13 +0200 Subject: [PATCH 013/132] Fix NEON gemm/gemv for legacy quants when row size is not divisible by 128 (#213) * Fix gemm/gemv for legacy quants when row size is not divisible by 128 * Fix typo --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 33e0a4a7..e8150ec5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -11310,9 +11310,9 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i q8.process_scales(i, deq, sc16, acc); sum_4(i, deq, q8, sc16, acc); } - //for (int i = 4*(nb/4); i < nb; ++i) { - // q8.process_1_block(i, deq, acc); - //} + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq, acc); + } for (int iy = 0; iy < Q8::nrc_y; ++iy) { info.store(ix, iy, vaddvq_f32(acc[iy])); @@ -11387,15 +11387,13 @@ static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { - if (nrc_x%2 == 0) { + if (nrc_x%2 == 0 && n%128 == 0) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x); } else { Dequantizer deq(vx, bx); mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } - //Dequantizer deq(vx, bx); - //mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); } } @@ -11406,7 +11404,7 @@ static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); } else { - if (nrc_x%2 == 0) { + if (nrc_x%2 == 0 && n%128 == 0) { Dequantizer deq1(vx, bx), deq2(vx, bx); mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x); } else { From 4b45b82e67d9362e7522e5c7107e9d99219e0432 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 20 Feb 2025 17:42:07 +0200 Subject: [PATCH 014/132] Honor attn_output specified in the command line also for low-bit quants --- src/llama.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 0257a0a3..28e887ee 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16339,7 +16339,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { - if (qs.model.hparams.n_expert >= 4) { + if (qs.params->attn_output_type < GGML_TYPE_COUNT) new_type = qs.params->attn_output_type; + else if (qs.model.hparams.n_expert >= 4) { new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; From b9a6639ac3bc77c64bba679cb85b14de0c4a9c9d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 21 Feb 2025 15:33:25 +0200 Subject: [PATCH 015/132] Hopefully this really fixes the confusion between AVX512 and FANCY_SIMD (#216) Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 3 +- ggml/src/ggml.c | 4 +- ggml/src/iqk/iqk_config.h | 30 ++++++++++ ggml/src/iqk/iqk_mul_mat.cpp | 108 +++++++++++++--------------------- ggml/src/iqk/iqk_quantize.cpp | 10 +--- 5 files changed, 76 insertions(+), 79 deletions(-) create mode 100644 ggml/src/iqk/iqk_config.h diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3d1a2970..0ed84956 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -254,6 +254,7 @@ if (GGML_BLAS) endif() set (GGML_SOURCES_IQK iqk/iqk_quantize.cpp) +set (GGML_HEADERS_IQK iqk/iqk_config.h) if (GGML_IQK_MUL_MAT) message(STATUS "Using optimized iqk matrix multiplications") add_compile_definitions(GGML_USE_IQK_MULMAT) @@ -1324,7 +1325,7 @@ add_library(ggml ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS} ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} ${GGML_SOURCES_IQK_MM} ${GGML_HEADERS_IQK_MM} - ${GGML_SOURCES_IQK} + ${GGML_SOURCES_IQK} ${GGML_HEADERS_IQK} ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN} ggml-aarch64.c ggml-aarch64.h ) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0aee8dd4..ad092923 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14,6 +14,7 @@ #include "iqk/iqk_quantize.h" #if GGML_USE_IQK_MULMAT #include "iqk/iqk_mul_mat.h" +#include "iqk/iqk_config.h" #endif #if defined(_MSC_VER) || defined(__MINGW32__) @@ -847,11 +848,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, #if GGML_USE_IQK_MULMAT -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) +#ifdef HAVE_FANCY_SIMD // Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2 // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range // (and it gets satured if it does), leading to wrong results. - // TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above. .vec_dot_type = GGML_TYPE_Q8_1_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, diff --git a/ggml/src/iqk/iqk_config.h b/ggml/src/iqk/iqk_config.h new file mode 100644 index 00000000..fa4972b8 --- /dev/null +++ b/ggml/src/iqk/iqk_config.h @@ -0,0 +1,30 @@ +#pragma once + +#if defined IQK_IMPLEMENT +#undef IQK_IMPLEMENT +#endif + +#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD +#define IQK_IMPLEMENT +#endif + +#ifdef _MSC_VER +#define IQK_NOINLINE __declspec(noinline) +#define IQK_ALWAYS_INLINE inline +#if !defined __x86_64__ && defined _M_X64 +#define __x86_64__ +#endif +#else +#define IQK_NOINLINE __attribute__((__noinline__)) +#define IQK_ALWAYS_INLINE __attribute__((__always_inline__)) +#endif + +#if defined __x86_64__ +#if defined HAVE_FANCY_SIMD + #undef HAVE_FANCY_SIMD +#endif +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + #define HAVE_FANCY_SIMD +#endif +#endif + diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index e8150ec5..5750b952 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7,20 +7,14 @@ // SPDX-License-Identifier: MIT // -#if defined IQK_IMPLEMENT -#undef IQK_IMPLEMENT -#endif +#include "iqk_config.h" -#if defined __AVX2__ || defined __ARM_FEATURE_DOTPROD -#define IQK_IMPLEMENT -#endif +#if defined IQK_IMPLEMENT #include #include #include -#if defined IQK_IMPLEMENT - #include "ggml-impl.h" #include "ggml-quants.h" #include "iqk_mul_mat.h" @@ -100,26 +94,6 @@ struct Perf { }; #endif -#ifdef _MSC_VER -#define IQK_NOINLINE __declspec(noinline) -#define IQK_ALWAYS_INLINE inline -#if !defined __x86_64__ && defined _M_X64 -#define __x86_64__ -#endif -#else -#define IQK_NOINLINE __attribute__((__noinline__)) -#define IQK_ALWAYS_INLINE __attribute__((__always_inline__)) -#endif - -#if defined __x86_64__ -#if defined HAVE_FANCY_SIMD - #undef HAVE_FANCY_SIMD -#endif -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) - #define HAVE_FANCY_SIMD -#endif -#endif - namespace { typedef struct { @@ -1472,7 +1446,7 @@ inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { template inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { if (j == 0) { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); @@ -1489,7 +1463,7 @@ inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, } #endif } else { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD for (int iy = 0; iy < Q8::nrc_y; ++iy) { sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); @@ -2747,7 +2721,7 @@ struct DequantizerIQ6K final : public BaseDequantizer { auto h1 = _mm256_andnot_si256(mask4, hbits); auto mask2 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh1), mh1); auto mask3 = _mm256_cmpeq_epi8(_mm256_and_si256(h1, mh2), mh2); - auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(0xff)); + auto mask1 = _mm256_andnot_si256(_mm256_or_si256(mask4, _mm256_or_si256(mask2, mask3)), _mm256_set1_epi8(-1)); // 0xff; return _mm256_or_si256(_mm256_or_si256(_mm256_and_si256(mask1, _mm256_shuffle_epi8(values[0], l)), _mm256_and_si256(mask2, _mm256_shuffle_epi8(values[1], l))), _mm256_or_si256(_mm256_and_si256(mask3, _mm256_shuffle_epi8(values[2], l)), @@ -2843,7 +2817,7 @@ struct DequantizerIQ4KSS final : public BaseDequantizer { const __m256i values; __m256i data[4]; const __m256i smask = _mm256_set_epi64x(0x0080004000200010, 0x0008000400020001, 0x0080004000200010, 0x0008000400020001); - const __m256i bmask = _mm256_set1_epi16(0xfffe); + const __m256i bmask = _mm256_set1_epi16(-2); // 0xfffe; const __m128i mask = _mm_set1_epi16(254); const __m128i m127 = _mm_set1_epi16(-127); const __m128i m128 = _mm_set1_epi16(-128); @@ -7049,7 +7023,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI template inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); @@ -7065,7 +7039,7 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons sumi[1] = _mm256_add_epi32(p2, p4); #endif } else { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) +#ifdef HAVE_FANCY_SIMD auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]); auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]); auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]); @@ -7282,7 +7256,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const __m256i accd[nrc_y]; __m256i val[4]; -#if !(defined __AVX512VNNI__ && defined __AVX512VL__) +#ifndef HAVE_FANCY_SIMD const auto m1_16 = _mm256_set1_epi16(1); #endif @@ -7304,7 +7278,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const for (int i = 0; i < nb/2; ++i) { deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1)); acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3)); #else @@ -7328,7 +7302,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), @@ -7349,7 +7323,7 @@ IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const if (i < nb) { deq.prepare_iq1bn_quants(x + i, val[0], val[1]); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); #else @@ -7401,7 +7375,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const __m256i accd[nrc_y]; __m256i val[4]; -#if !(defined __AVX512VNNI__ && defined __AVX512VL__) +#ifndef HAVE_FANCY_SIMD const auto m1_16 = _mm256_set1_epi16(1); #endif @@ -7413,7 +7387,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const __m256i acc[2] = {}; for (int i = 0; i < nb/2; ++i) { deq.prepare4(i, val); -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1)); acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)), @@ -7436,7 +7410,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const for (int i = 0; i < nb/2; ++i) { deq.prepare4(i, val); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3)); @@ -7455,7 +7429,7 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const if (i < nb) { deq.prepare2(i, val); for (int iy = 0; iy < nrc_y; ++iy) { -#if defined __AVX512VNNI__ && defined __AVX512VL__ +#ifdef HAVE_FANCY_SIMD accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); #else @@ -8537,7 +8511,7 @@ template struct QFT final : public QFBase { xv[1] = load1(ix+1, i); xv[2] = load1(ix+2, i); xv[3] = load1(ix+3, i); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]); auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]); auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]); @@ -14749,7 +14723,7 @@ struct BaseHelper { }; struct F16 { -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ using Data = __m512; constexpr static int block_size = 16; constexpr static int num_registers = 32; @@ -14910,7 +14884,7 @@ struct HelperQ8KV final : public BaseHelper { v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); #else auto vd = F16::set1(q8->d); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+0)))); v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)q8->qs+i+1)))); #else @@ -14945,7 +14919,7 @@ struct HelperQ80 final : public BaseHelper { v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0)))); v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1)))); #else @@ -15215,7 +15189,7 @@ struct HelperQ40 final : public BaseHelper { #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_add_epi8(_mm_and_si128(q, mask), m8); auto qh = _mm_add_epi8(_mm_and_si128(_mm_srli_epi16(q, 4), mask), m8); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); @@ -15260,7 +15234,7 @@ struct HelperQ41 final : public BaseHelper { auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto vm = F16::set1(GGML_FP16_TO_FP32(dl->m)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_and_si128(q, mask); auto qh = _mm_and_si128(_mm_srli_epi16(q, 4), mask); v1 = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)), vm); @@ -15306,7 +15280,7 @@ struct HelperIQ4nl final : public BaseHelper { #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto q = _mm_loadu_si128((const __m128i *)dl->qs); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask)); auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask)); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); @@ -15361,7 +15335,7 @@ struct HelperQ60 final : public BaseHelper { auto bl = _mm_loadu_si128((const __m128i *)dl->qs); uint64_t aux64; std::memcpy(&aux64, dl->qh, 8); auto bh = _mm_set_epi64x(aux64, aux64 << 4); -#ifdef HAVE_FANCY_SIMD +#ifdef __AVX512F__ auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32); auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32); v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); @@ -15537,6 +15511,22 @@ struct FlashMS { } return F16::reduce_max(vk); } + static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) { + auto m128 = _mm_loadu_si128((const __m128i *)mask+l); + m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128()); + auto m256 = _mm256_cvtepi16_epi32(m128); + auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); + return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + } +#ifdef __AVX512F__ + static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) { + auto m256 = _mm256_loadu_si256((const __m256i *)mask+l); + m256 = _mm256_cmpeq_epi16(m256, _mm256_setzero_si256()); + auto m512 = _mm512_cvtepi16_epi32(m256); + auto mf = _mm512_castsi512_ps(_mm512_or_si512(m512, _mm512_slli_epi32(m512, 16))); + return _mm512_or_ps(_mm512_and_ps(mf, val), _mm512_andnot_ps(mf, vinf)); + } +#endif inline float load_apply_mask_and_scale(int j, F16::Data * vk, const char * mask) { #ifdef HAVE_FANCY_SIMD auto vzero = _mm256_set1_epi16(0); @@ -15554,15 +15544,9 @@ struct FlashMS { } } #else - auto vzero = _mm_set1_epi16(0); auto vinf = F16::set1(-INFINITY); for (int l = 0; l < k_step/F16::block_size; ++l) { - auto m128 = _mm_loadu_si128((const __m128i *)mask + l); - m128 = _mm_cmpeq_epi16(m128, vzero); - auto m256 = _mm256_cvtepi16_epi32(m128); - auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); - auto val = _mm256_loadu_ps(cache + k_step*j + F16::block_size*l); - vk[l] = _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + vk[l] = apply_mask(l, mask, F16::load(cache + k_step*j + F16::block_size*l), vinf); } if (softcap <= 0) { for (int l = 0; l < k_step/F16::block_size; ++l) vk[l] = F16::mul(vscale, vk[l]); @@ -15630,14 +15614,12 @@ struct FlashQKV { for (int i = 0; i < D/F16::block_size; ++i) vq[i] = F16::mul(vms, vq[i]); } } - //F16::Data v[8]; F16::Data v0, v1; for (int l = 0; l < k_step; l += 4) { auto vs0 = F16::set1(fms.cache[l + 0]); auto vs1 = F16::set1(fms.cache[l + 1]); auto vs2 = F16::set1(fms.cache[l + 2]); auto vs3 = F16::set1(fms.cache[l + 3]); - //auto vs = F16::set4(fms.cache + l); for (int i = 0; i < D/F16::block_size; i += 2) { vh.load(l+0, i, v0, v1); vq[i+0] = F16::fmadd(vq[i+0], v0, vs0); @@ -15651,14 +15633,6 @@ struct FlashQKV { vh.load(l+3, i, v0, v1); vq[i+0] = F16::fmadd(vq[i+0], v0, vs3); vq[i+1] = F16::fmadd(vq[i+1], v1, vs3); - //vq[i+0] = F16::fmadd_lane0(vq[i+0], v[0], vs); - //vq[i+1] = F16::fmadd_lane0(vq[i+1], v[4], vs); - //vq[i+0] = F16::fmadd_lane1(vq[i+0], v[1], vs); - //vq[i+1] = F16::fmadd_lane1(vq[i+1], v[5], vs); - //vq[i+0] = F16::fmadd_lane2(vq[i+0], v[2], vs); - //vq[i+1] = F16::fmadd_lane2(vq[i+1], v[6], vs); - //vq[i+0] = F16::fmadd_lane3(vq[i+0], v[3], vs); - //vq[i+1] = F16::fmadd_lane3(vq[i+1], v[7], vs); } } for (int i = 0; i < D/F16::block_size; ++i) F16::store(qkv_cache + F16::block_size*i, vq[i]); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 7b777a1f..b61ae2db 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -12,6 +12,7 @@ #define GGML_COMMON_IMPL_C #include "ggml-common.h" #include "iqk_quantize.h" +#include "iqk_config.h" #include #include @@ -43,15 +44,6 @@ constexpr int popcount(uint32_t x) { return __builtin_popcount(x); } constexpr int popcount(uint64_t x) { return __builtin_popcountll(x); } #endif -#if defined __x86_64__ -#if defined HAVE_FANCY_SIMD - #undef HAVE_FANCY_SIMD -#endif -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) - #define HAVE_FANCY_SIMD -#endif -#endif - namespace { inline int nearest_int(float fval) { From c4a5103299e44adc8692e3e373c1974fa9fee270 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 22 Feb 2025 09:38:51 +0200 Subject: [PATCH 016/132] Better strategy for attention matrix multiplications when generating tokens (#218) * This seems to be a better way to do the attention matrix multiplications in the TG case. * Cleanup --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 58 ++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5750b952..fffddada 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -352,6 +352,26 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, } return true; } + + if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) { + //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx); + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + int n_per_thread = (Nx + nth - 1)/nth; + int first = ith*n_per_thread; + if (first >= Nx) return true; + int last = first + n_per_thread <= Nx ? first + n_per_thread : Nx; + for (int ix = first; ix < last; ++ix) { + for (int i02 = 0; i02 < ne02; ++i02) { + DataInfo info{C + ix + i02*nb2, (const char *)B + i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[0](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), nb02, info, 1); + } + } + return true; + } + int gcd = simple_gcd(ne12*ne13, nth); int counter = 0; for (int64_t i13 = 0; i13 < ne13; i13++) { @@ -6229,8 +6249,8 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn // The HAVE_FANCY_SIMD should only be #if defined(__AVX512_VNNI__ && defined(__AVX512VL__) template static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); GGML_ASSERT(n%32 == 0); + GGML_ASSERT(nrc_x%8 == 0); #ifndef HAVE_FANCY_SIMD auto m1 = _mm256_set1_epi16(1); #endif @@ -6298,8 +6318,42 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data template static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); GGML_ASSERT(n%32 == 0); + if (nrc_y == 1 && nrc_x == 1) { + auto dx = (const float *)vx; + auto dy = (const float *)info.src1_row(0); +#ifdef HAVE_FANCY_SIMD + auto sy = (const int32_t *)(dy + 1); + auto x = (const int8_t *)(dx + 2); + auto y = (const int8_t *)(dy + 2); + auto isum = _mm512_setzero_si512(); + for (int i = 0; i < n/64; ++i) { + auto qx = _mm512_loadu_si512((const __m512i *)x + i); + auto qy = _mm512_loadu_si512((const __m512i *)y + i); + isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy); + } + auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1)); + for (int i = 2*(n/64); i < n/32; ++i) { + auto qx = _mm256_loadu_si256((const __m256i *)x + i); + auto qy = _mm256_loadu_si256((const __m256i *)y + i); + isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy); + } + info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0])); +#else + auto x = (const int8_t *)(dx + 2); + auto y = (const int8_t *)(dy + 2); + auto isum = _mm256_setzero_si256(); + for (int i = 0; i < n/32; ++i) { + auto qx = _mm256_loadu_si256((const __m256i *)x + i); + auto qy = _mm256_loadu_si256((const __m256i *)y + i); + auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx)); + isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot)); + } + info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum)); +#endif + return; + } + GGML_ASSERT(nrc_x%8 == 0); __m256i qx[2]; __m256i acc[2*nrc_y] = {}; float dy[nrc_y]; From 33646fc40949e0fdcf16d96f5b40d12bf93244a9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 22 Feb 2025 09:41:40 +0200 Subject: [PATCH 017/132] Fuse MoE up and gate matrix multiplications (#219) * This seems to be a better way to do the attention matrix multiplications in the TG case. * Cleanup * Fuse up and gate gemms in MoE models Small (~1-2%) but measurable performan ce gain --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 161 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 158 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ad092923..eb39d574 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -14581,6 +14581,149 @@ IQK_MulMat_Not_Available:; #undef MMID_MATRIX_ROW } +#if GGML_USE_IQK_MULMAT +static void ggml_compute_forward_mul_mat_id_up_gate( + const struct ggml_compute_params * params, + struct ggml_tensor * dst1, + struct ggml_tensor * dst2) { + + GGML_ASSERT(dst1->src[1] == dst2->src[1]); + GGML_ASSERT(dst1->src[2] == dst2->src[2]); + GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type); + GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32); + + const struct ggml_tensor * src1 = dst1->src[1]; + const struct ggml_tensor * ids = dst1->src[2]; + const struct ggml_tensor * src0_1 = dst1->src[0]; + const struct ggml_tensor * src0_2 = dst2->src[0]; + const struct ggml_tensor * src0 = src0_1; + const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + GGML_ASSERT(ne13 == 1); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + char * wdata_src1_end = (src1->type == vec_dot_type) ? + (char *) params->wdata : + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, src1->ne[0])*ggml_nrows(src1), sizeof(int64_t)); + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + + if (src1->type != vec_dot_type) { + + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + + char * wdata = params->wdata; + + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } + } + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; + matrix_row_counts[i02] += 1; + } + } + } + + ggml_barrier(params->shared); + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02; + const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows + + if (nth%2 == 0) { + const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; + void * dst_d = ith%2 == 0 ? dst1->data : dst2->data; + if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, + type, src0_d, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst_d, nb1, nb2, + matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); + + } else { + if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, + src0_1->type, (const char *)src0_1_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst1->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); + if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, + src0_2->type, (const char *)src0_2_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst2->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); + } + } + +#undef MMID_MATRIX_ROW +} +#endif + // ggml_compute_forward_out_prod static void ggml_compute_forward_out_prod_f32( @@ -19007,17 +19150,18 @@ static void ggml_compute_forward_cross_entropy_loss_back( ///////////////////////////////// -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { +static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) { GGML_ASSERT(params); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { - return; + return false; } #if IK_PRINT_TIMING int64_t t1 = ggml_time_us(); #endif + bool skip_next = false; switch (tensor->op) { case GGML_OP_DUP: { @@ -19125,6 +19269,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT_ID: { +#if GGML_USE_IQK_MULMAT + if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] && + tensor->src[0]->type == next->src[0]->type) { + ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next); + skip_next = true; + break; + } +#endif ggml_compute_forward_mul_mat_id(params, tensor); } break; case GGML_OP_OUT_PROD: @@ -19367,6 +19519,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm int64_t t2 = ggml_time_us(); if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1)); #endif + return skip_next; } //////////////////////////////////////////////////////////////////////////////// @@ -21219,7 +21372,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (ggml_is_noop(node)) continue; - ggml_compute_forward(¶ms, node); + if (ggml_compute_forward(¶ms, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) { + ++node_n; + } if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { state->shared->ec = GGML_STATUS_ABORTED; From 49261058442cfe382dab3270fcd86652296a75c0 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 22 Feb 2025 14:25:38 +0200 Subject: [PATCH 018/132] Fix #217 (#220) * Fix #217 * Remove stuff commited by mistake --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index fffddada..2c4987a0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15841,23 +15841,18 @@ struct FlashQKfp32 { #endif constexpr int qrem = q_step - nrc_q*(q_step/nrc_q); constexpr int krem = k_step - nrc_k*(k_step/nrc_k); + static_assert(krem == 0); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; for (int iq = 0; iq < q_step/nrc_q; ++iq) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); } - if constexpr (krem > 0) { - mul_mat_Qx_Qy_MxN_fa, QFT>(D, kh.block, kh.stride, k_step - krem, info); - } info.cur_y += nrc_q; } if constexpr (qrem > 0) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); } - if constexpr (krem > 0) { - mul_mat_Qx_Qy_MxN_fa, QFT>(D, kh.block, kh.stride, k_step - krem, info); - } } F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < q_step; ++j) { @@ -15910,7 +15905,7 @@ struct FlashQKfp32 { constexpr int nrc_k = 8; #endif static_assert(k_step%nrc_k == 0); - int qrem = q_step - nrc_q*(q_step/nrc_q); + int qrem = nq - nrc_q*(nq/nrc_q); DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; for (int iq = 0; iq < nq/nrc_q; ++iq) { for (int ik = 0; ik < k_step/nrc_k; ++ik) { @@ -15960,7 +15955,7 @@ struct FlashQKfp32 { } } F16::Data vk[k_step/F16::block_size]; - for (int j = 0; j < q_step; ++j) { + for (int j = 0; j < nq; ++j) { fms.update_M_S(j, vk, mask + stride_m*j); } } From 71b7b510c2dc55ae70934d246cd4e6c3bdf4a95c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 23 Feb 2025 08:02:16 +0200 Subject: [PATCH 019/132] Fix compilation error with IQK_FA_ALL_QUANTS enabled (#226) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 2c4987a0..ed7309cd 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15315,9 +15315,11 @@ struct HelperIQ4nl final : public BaseHelper { #ifdef __aarch64__ using block_q8 = block_q8_0; HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {} + constexpr static int block_size_q = QK8_0; #else HelperIQ4nl(const char * data, int stride) : Base(data, stride) {} using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_1; #endif // Needed for v * softmax(k * q) From 46bf73a37f1aabe6f0b40365b0c7b2ba831905f5 Mon Sep 17 00:00:00 2001 From: saood06 Date: Sun, 23 Feb 2025 00:16:27 -0600 Subject: [PATCH 020/132] Add new sweep-bench benchmark (#225) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * examples : add new sweep-bench benchmark * Change documentation to reference ik_llama.cpp * Made it compile with ik_llama * Fix JSONL output --------- Co-authored-by: Stanisław Szymczyk --- common/common.cpp | 9 ++ common/common.h | 2 + examples/CMakeLists.txt | 1 + examples/sweep-bench/CMakeLists.txt | 5 + examples/sweep-bench/README.md | 64 ++++++++ examples/sweep-bench/sweep-bench-plot.py | 100 ++++++++++++ examples/sweep-bench/sweep-bench.cpp | 189 +++++++++++++++++++++++ 7 files changed, 370 insertions(+) create mode 100644 examples/sweep-bench/CMakeLists.txt create mode 100644 examples/sweep-bench/README.md create mode 100755 examples/sweep-bench/sweep-bench-plot.py create mode 100644 examples/sweep-bench/sweep-bench.cpp diff --git a/common/common.cpp b/common/common.cpp index f7a6f76f..6bf6e4f9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1360,6 +1360,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.warmup = false; return true; } + if (arg == "--output-format") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "jsonl") { params.sweep_bench_output_jsonl = true; } + else if (value == "md") { params.sweep_bench_output_jsonl = false; } + else { invalid_param = true; } + return true; + } + #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { diff --git a/common/common.h b/common/common.h index fc1ae619..b5b67986 100644 --- a/common/common.h +++ b/common/common.h @@ -269,6 +269,8 @@ struct gpt_params { bool spm_infill = false; // suffix/prefix/middle pattern for infill std::string lora_outfile = "ggml-lora-merged-f16.gguf"; + + bool sweep_bench_output_jsonl = false; }; void gpt_params_handle_hf_token(gpt_params & params); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 67b3d277..3987fe13 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -51,5 +51,6 @@ else() add_subdirectory(save-load-state) add_subdirectory(simple) add_subdirectory(speculative) + add_subdirectory(sweep-bench) add_subdirectory(tokenize) endif() diff --git a/examples/sweep-bench/CMakeLists.txt b/examples/sweep-bench/CMakeLists.txt new file mode 100644 index 00000000..e49f0fea --- /dev/null +++ b/examples/sweep-bench/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-sweep-bench) +add_executable(${TARGET} sweep-bench.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/sweep-bench/README.md b/examples/sweep-bench/README.md new file mode 100644 index 00000000..608fd104 --- /dev/null +++ b/examples/sweep-bench/README.md @@ -0,0 +1,64 @@ +# ik_llama.cpp/example/sweep-bench + +Benchmark the prompt processing and token generation performance of `ik_llama.cpp` +by doing a sweep over a whole context size and gathering performance metrics +in each ubatch-sized window. Only a single token sequence is used. + +The benchmark steps are: + +for each ubatch-sized window in context: + 1. generate ubatch/4 tokens (not the whole window to save some time) + 2. measure generation performance + 3. remove generated tokens from KV cache + 4. prepare a ubatch-sized batch of random tokens + 4. process prepated batch + 5. measure prompt processing performance + +The purpose of the benchmark is to visualize how the performance changes with +the context size without averaging the metrics values over the whole context. + +## Usage + +./llama-sweep-bench -c 8704 -ub 512 -m models/Meta-Llama-3.2-3B-Instruct-Q8_0.gguf + +## Sample results + +- `PP` - prompt tokens per ubatch +- `TG` - generated tokens per ubatch +- `N_KV` - current KV cache size +- `T_PP` - prompt processing time (i.e. time to first token) +- `S_PP` - prompt processing speed (`(B*PP)/T_PP` or `PP/T_PP`) +- `T_TG` - time to generate all batches +- `S_TG` - text generation speed (`(B*TG)/T_TG`) + +| PP | TG | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | +|-------|--------|--------|----------|----------|----------|----------| +| 512 | 128 | 0 | 1.100 | 465.51 | 2.311 | 55.38 | +| 512 | 128 | 512 | 1.183 | 432.97 | 1.895 | 67.55 | +| 512 | 128 | 1024 | 1.305 | 392.38 | 2.071 | 61.81 | +| 512 | 128 | 1536 | 1.279 | 400.42 | 2.164 | 59.14 | +| 512 | 128 | 2048 | 1.571 | 325.96 | 2.280 | 56.14 | +| 512 | 128 | 2560 | 1.431 | 357.87 | 2.418 | 52.94 | +| 512 | 128 | 3072 | 1.515 | 337.93 | 2.566 | 49.88 | +| 512 | 128 | 3584 | 1.588 | 322.34 | 2.722 | 47.03 | +| 512 | 128 | 4096 | 1.675 | 305.70 | 2.864 | 44.69 | +| 512 | 128 | 4608 | 1.769 | 289.50 | 2.999 | 42.68 | +| 512 | 128 | 5120 | 1.845 | 277.48 | 3.102 | 41.26 | +| 512 | 128 | 5632 | 1.893 | 270.46 | 3.219 | 39.76 | +| 512 | 128 | 6144 | 1.953 | 262.20 | 3.348 | 38.23 | +| 512 | 128 | 6656 | 2.018 | 253.71 | 3.474 | 36.84 | +| 512 | 128 | 7168 | 2.078 | 246.34 | 3.589 | 35.66 | +| 512 | 128 | 7680 | 2.140 | 239.22 | 3.717 | 34.43 | +| 512 | 128 | 8192 | 2.196 | 233.15 | 3.854 | 33.21 | + +### JSONL output + +Pass `--output-format jsonl` to output JSONL instead of Markdown, á la + +```json lines +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 0, "t_pp": 1.093814, "speed_pp": 468.086884, "t_tg": 1.780312, "speed_tg": 71.897514 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 512, "t_pp": 1.169302, "speed_pp": 437.868073, "t_tg": 1.897474, "speed_tg": 67.458099 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 1024, "t_pp": 1.183700, "speed_pp": 432.542053, "t_tg": 2.059179, "speed_tg": 62.160694 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 1536, "t_pp": 1.428625, "speed_pp": 358.386566, "t_tg": 2.160639, "speed_tg": 59.241734 } +{"n_kv_max": 8704, "n_batch": 2048, "n_ubatch": 512, "flash_attn": 0, "n_gpu_layers": -1, "n_threads": 32, "n_threads_batch": 32, "pp": 512, "tg": 128, "n_kv": 2048, "t_pp": 1.360647, "speed_pp": 376.291595, "t_tg": 2.274003, "speed_tg": 56.288403 } +``` diff --git a/examples/sweep-bench/sweep-bench-plot.py b/examples/sweep-bench/sweep-bench-plot.py new file mode 100755 index 00000000..0d3c81db --- /dev/null +++ b/examples/sweep-bench/sweep-bench-plot.py @@ -0,0 +1,100 @@ +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('file', nargs='+') +args = parser.parse_args() + +df = None + +for jsonl_file in args.file: + # Read JSONL file into DataFrame + df_part = pd.read_json(jsonl_file, lines=True) + df_part['label'] = jsonl_file + if df is None: + df = df_part + else: + df = pd.concat([df, df_part]) + +# Group by model and n_kv, calculate mean and std for both speed metrics +df_grouped = df.groupby(['label', 'n_kv']).agg({ + 'speed_pp': ['mean', 'std'], + 'speed_tg': ['mean', 'std'] +}).reset_index() + +# Flatten multi-index columns +df_grouped.columns = ['label', 'n_kv', 'speed_pp_mean', 'speed_pp_std', + 'speed_tg_mean', 'speed_tg_std'] + +# Replace NaN with 0 (std for a single sample is NaN) + +df_grouped['speed_pp_std'] = df_grouped['speed_pp_std'].fillna(0) +df_grouped['speed_tg_std'] = df_grouped['speed_tg_std'].fillna(0) + +# Prepare ticks values for X axis (prune for readability) +x_ticks = df['n_kv'].unique() +while len(x_ticks) > 16: + x_ticks = x_ticks[::2] + +# Get unique labels and color map +labels = df_grouped['label'].unique() +colors = plt.cm.rainbow(np.linspace(0, 1, len(labels))) + +# Create prompt processing plot +plt.figure(figsize=(10, 6)) +ax1 = plt.gca() + +plt.grid() + +ax1.set_xticks(x_ticks) + +# Plot each label's data +for label, color in zip(labels, colors): + label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv') + + # Plot prompt processing + pp = ax1.errorbar(label_data['n_kv'], label_data['speed_pp_mean'], + yerr=label_data['speed_pp_std'], color=color, + marker='o', linestyle='-', label=label) + +# Add labels and title +ax1.set_xlabel('Context Length (tokens)') +ax1.set_ylabel('Prompt Processing Rate (t/s)') +plt.title('Prompt Processing Performance Comparison') + +ax1.legend(loc='upper right') + +# Adjust layout and save +plt.tight_layout() +plt.savefig('performance_comparison_pp.png', bbox_inches='tight') +plt.close() + +# Create token generation plot +plt.figure(figsize=(10, 6)) +ax1 = plt.gca() + +plt.grid() +ax1.set_xticks(x_ticks) + +# Plot each model's data +for label, color in zip(labels, colors): + label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv') + + # Plot token generation + tg = ax1.errorbar(label_data['n_kv'], label_data['speed_tg_mean'], + yerr=label_data['speed_tg_std'], color=color, + marker='s', linestyle='-', label=label) + +# Add labels and title +ax1.set_xlabel('Context Length (n_kv)') +ax1.set_ylabel('Token Generation Rate (t/s)') +plt.title('Token Generation Performance Comparison') + +ax1.legend(loc='upper right') + +# Adjust layout and save +plt.tight_layout() +plt.savefig('performance_comparison_tg.png', bbox_inches='tight') +plt.close() diff --git a/examples/sweep-bench/sweep-bench.cpp b/examples/sweep-bench/sweep-bench.cpp new file mode 100644 index 00000000..4e594de5 --- /dev/null +++ b/examples/sweep-bench/sweep-bench.cpp @@ -0,0 +1,189 @@ +#include "ggml.h" +#include "llama.h" +#include "common.h" +#include "llama-vocab.h" + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -c 8192 -b 2048 -ub 512\n", argv[0]); + LOG("\n"); +} + +int main(int argc, char ** argv) { + + gpt_params params; + + if (!gpt_params_parse(argc, argv, params)) { + print_usage(argc, argv); + return 1; + } + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = llama_model_params_from_gpt_params(params); + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + llama_context_params ctx_params = llama_context_params_from_gpt_params(params); + + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + const unsigned int n_kv_max = llama_n_ctx(ctx); + + + const llama_vocab * vocab = llama_get_vocab(ctx); + llama_token bos = llama_token_bos_impl(*vocab); + //llama_token eos = llama_token_eos_impl(*vocab); + + const unsigned int n_vocab = llama_n_vocab(model); + + // decode in batches of ctx_params.n_batch tokens + auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + LOG("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + return false; + } + + llama_synchronize(ctx); + } + + return true; + }; + + const unsigned int pp = params.n_ubatch; + const unsigned int tg = params.n_ubatch / 4; + + if (!params.sweep_bench_output_jsonl) { + LOG("\n"); + LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG("\n"); + LOG("|%6s | %6s | %6s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s"); + LOG("|%6s-|-%6s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "------", "--------", "--------", "--------", "--------"); + } + + llama_batch batch = llama_batch_init(n_kv_max, 0, 1); + + // warm up + { + llama_batch_add(batch, bos, 0, { 0 }, false); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG("%s: llama_decode() failed\n", __func__); + return 1; + } + } + + llama_batch_clear(batch); + llama_kv_cache_clear(ctx); + + for (unsigned int n_kv = 0; n_kv < n_kv_max; n_kv += params.n_ubatch) { + // clean up KV cache before generation + llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + // first measure token generation performance at this context size + const auto t_tg_start = ggml_time_us(); + + for (unsigned int i = 0; i < tg; ++i) { + llama_batch_clear(batch); + llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG("%s: llama_decode() failed\n", __func__); + return 1; + } + } + + const auto t_tg_end = ggml_time_us(); + + // clean up KV cache after generation + llama_kv_cache_seq_rm(ctx, 0, n_kv, -1); + + // prepare batch of pp size for prompt processing performance measurement + llama_batch_clear(batch); + + for (unsigned int i = 0; i < pp; ++i) { + llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false); + } + batch.logits[batch.n_tokens - 1] = true; + + // measure prompt processing performance + const auto t_pp_start = ggml_time_us(); + + if (!decode_helper(ctx, batch, ctx_params.n_batch)) { + LOG("%s: llama_decode() failed\n", __func__); + return 1; + } + + const auto t_pp_end = ggml_time_us(); + + // calculate and print metrics + const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f; + const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f; + + const float speed_pp = pp / t_pp; + const float speed_tg = tg / t_tg; + + if(params.sweep_bench_output_jsonl) { + LOG( + "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, " + "\"pp\": %d, \"tg\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f }\n", + n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch, + pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg + ); + } else { + LOG("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg); + } + } + + llama_batch_free(batch); + + llama_free(ctx); + llama_free_model(model); + + llama_backend_free(); + + return 0; +} From ac1d259b93eccfa7371c6b00c5749400ff2b2aea Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 23 Feb 2025 14:31:11 +0200 Subject: [PATCH 021/132] Fused MoE ffn_up and ffn_gate (#229) * Fusing MoE up * unary(gate) * Fusing MoE up * unary(gate): CUDA We get ~13% speedup for PP-512 and ~2% for TG-128 for DeepSeek-Lite * On CUDA also fuse MoE down * (up * unary(gate)) in case the MUL_MAT_ID op for the down experts is the next op in the graph. * Command line option to enable fused MoE up*unary(gate) * Add fmoe option to llama-bench * Adding forgotten gelu, relu, silu on ARM --------- Co-authored-by: Iwan Kawrakow --- common/common.cpp | 7 + common/common.h | 1 + examples/llama-bench/llama-bench.cpp | 33 +++- ggml/include/ggml.h | 10 + ggml/src/ggml-cuda.cu | 263 ++++++++++++++++++++++++++- ggml/src/ggml-cuda/unary.cu | 36 +++- ggml/src/ggml-cuda/unary.cuh | 2 + ggml/src/ggml.c | 155 +++++++++++----- ggml/src/iqk/iqk_mul_mat.cpp | 260 ++++++++++++++++++++++++++ ggml/src/iqk/iqk_mul_mat.h | 5 + include/llama.h | 1 + src/llama.cpp | 38 ++-- 12 files changed, 730 insertions(+), 81 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6bf6e4f9..f975aee3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -817,6 +817,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mla_attn = true; return true; } + if (arg == "-fmoe" || arg == "--fused-moe") { + params.fused_moe_up_gate = true; + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1466,6 +1470,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2303,6 +2308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.fused_moe_up_gate = params.fused_moe_up_gate; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3301,6 +3307,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); + fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index b5b67986..f86a58cb 100644 --- a/common/common.h +++ b/common/common.h @@ -175,6 +175,7 @@ struct gpt_params { bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention bool mla_attn = false; // MLA + bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 0222c213..b0790e20 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -241,6 +241,7 @@ struct cmd_params { bool verbose; bool warmup; bool repack = false; + bool fmoe = false; output_formats output_format; output_formats output_format_stderr; }; @@ -271,6 +272,7 @@ static const cmd_params cmd_params_defaults = { /* verbose */ false, /* warmup */ true, /* repack */ false, + /* fmoe */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -307,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); + printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -607,6 +610,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.repack = std::stoi(argv[i]); + } else if (arg == "-fmoe" || arg == "--fused-moe") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.fmoe = std::stoi(argv[i]); } else { invalid_param = true; break; @@ -675,6 +684,7 @@ struct cmd_params_instance { bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -714,6 +724,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.mla_attn = mla_attn; + cparams.fused_moe_up_gate = fmoe; cparams.embeddings = embeddings; return cparams; @@ -765,6 +776,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -794,6 +806,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -823,6 +836,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -852,6 +866,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -892,6 +907,7 @@ struct test { bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; int n_prompt; int n_gen; std::string test_time; @@ -922,6 +938,7 @@ struct test { use_mmap = inst.use_mmap; embeddings = inst.embeddings; repack = inst.repack; + fmoe = inst.fmoe; n_prompt = inst.n_prompt; n_gen = inst.n_gen; test_kind = inst.test_kind; @@ -1012,7 +1029,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", - "tensor_split", "use_mmap", "embeddings", "repack", + "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1033,7 +1050,8 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { + field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || + field == "fused_moe") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1068,7 +1086,7 @@ struct test { std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()), @@ -1240,6 +1258,9 @@ struct markdown_printer : public printer { if (field == "repack") { return 3; } + if (field == "fused_moe") { + return 4; + } if (field == "test") { return 13; } @@ -1277,6 +1298,9 @@ struct markdown_printer : public printer { if (field == "repack") { return "rtr"; } + if (field == "fused_moe") { + return "fmoe"; + } if (field == "embeddings") { return "embd"; } @@ -1338,6 +1362,9 @@ struct markdown_printer : public printer { if (params.repack != cmd_params_defaults.repack) { fields.emplace_back("repack"); } + if (params.fmoe != cmd_params_defaults.fmoe) { + fields.emplace_back("fused_moe"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d2131a15..d12b90d0 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -567,6 +567,7 @@ extern "C" { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, GGML_OP_OUT_PROD, + GGML_OP_MOE_FUSED_UP_GATE, GGML_OP_SCALE, GGML_OP_SET, @@ -1320,6 +1321,15 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); + // MoE up + gate + unary + GGML_API struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op); + // A: m columns, n rows, // B: p columns, n rows, // result is m columns, p rows diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e38e9568..26d06d56 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2195,7 +2195,252 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } -static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { +static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) { + const ggml_tensor * src0_1 = dst->src[0]; + const ggml_tensor * src0_2 = dst->src[1]; + const ggml_tensor * src0 = src0_1; + const ggml_tensor * src1 = dst->src[2]; + const ggml_tensor * ids = dst->src[3]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_2->buffer) && "mul_mat_id does not support split buffers"); + + cudaStream_t stream = ctx.stream(); + + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + ggml_tensor src0_1_row = *src0_1; + ggml_tensor src0_2_row = *src0_2; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + ggml_tensor final_dst; + ggml_tensor final_src; + + char * src0_1_original = (char *) src0_1->data; + char * src0_2_original = (char *) src0_2->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + + src0_1_row.ne[2] = 1; + src0_1_row.ne[3] = 1; + src0_1_row.nb[3] = nb02; + src0_2_row.ne[2] = 1; + src0_2_row.ne[3] = 1; + src0_2_row.nb[3] = nb02; + + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + bool fuse_down = false; + if (next && next->op == GGML_OP_MUL_MAT_ID) { + //printf("Fusing MoE down gemm\n"); + fuse_down = true; + final_dst = *next; + final_dst.ne[1] = final_dst.ne[2] = final_dst.ne[3] = 1; + final_dst.nb[2] = final_dst.nb[3] = final_dst.nb[1]; + final_src = *next->src[0]; + //printf("next->src[0]: %s, %d x %d x %d x %d and %d x %d x %d x %d\n", ggml_type_name(next->src[0]->type), + // (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + // (int)next->src[0]->nb[0], (int)next->src[0]->nb[1], (int)next->src[0]->nb[2], (int)next->src[0]->nb[3]); + final_src.ne[2] = final_src.ne[3] = 1; + final_src.nb[3] = final_src.nb[2]; + } + + if (ne12 == 1) { + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]); + if (fuse_down) { + final_dst.src[1] = &dst_row; + } + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + //dst_row.data = dst_original + i1*nb1 + i2*nb2; + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + if (fuse_down) { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + } else { + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); + CUDA_CHECK(cudaGetLastError()); + + } + } + } + } else { + ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); + if (fuse_down) { + final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); + final_dst.src[1] = &dst_row; + } + + src1_row.data = src1_contiguous.get(); + + bool first = false; //true; + + for (int64_t i02 = 0; i02 < n_as; i02++) { + int64_t num_src1_rows = 0; + + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + if (row_id_i != i02) { + continue; + } + + num_src1_rows++; + } + } + + if (num_src1_rows == 0) { + continue; + } + + ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); + CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + + { + dim3 block_dims(std::min((unsigned int)ne10, 768u)); + dim3 grid_dims(ids->ne[1], n_ids); + k_copy_src1_to_contiguous<<>>( + src1_original, src1_contiguous.get(), + dev_cur_src1_row.get(), dev_row_mapping.get(), + ids_dev, i02, ids->nb[1], ids->nb[0], + ne11, ne10, + nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } + + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + + src1_row.ne[1] = num_src1_rows; + src1_row.nb[1] = nb11; + src1_row.nb[2] = num_src1_rows*nb11; + src1_row.nb[3] = num_src1_rows*nb11; + + dst_row.ne[1] = num_src1_rows; + dst_row.nb[1] = nb1; + dst_row.nb[2] = num_src1_rows*nb1; + dst_row.nb[3] = num_src1_rows*nb1; + + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + if (fuse_down) { + + final_dst.ne[1] = num_src1_rows; + final_dst.nb[1] = final_dst.ne[0]*sizeof(float); + final_dst.nb[2] = final_dst.nb[3] = num_src1_rows*final_dst.nb[1]; + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + if (first) { + printf("Fusing down for %d rows: (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", (int)num_src1_rows, + (int)next->ne[0], (int)next->ne[1], (int)next->ne[2], (int)next->ne[3], + (int)next->src[0]->ne[0], (int)next->src[0]->ne[1], (int)next->src[0]->ne[2], (int)next->src[0]->ne[3], + (int)next->src[1]->ne[0], (int)next->src[1]->ne[1], (int)next->src[1]->ne[2], (int)next->src[1]->ne[3]); + printf(" using (%d x %d x %d x %d) = (%d x %d x %d x %d) * (%d x %d x %d x %d)\n", + (int)final_dst.ne[0], (int)final_dst.ne[1], (int)final_dst.ne[2], (int)final_dst.ne[3], + (int)final_src.ne[0], (int)final_src.ne[1], (int)final_src.ne[2], (int)final_src.ne[3], + (int)dst_row.ne[0], (int)dst_row.ne[1], (int)dst_row.ne[2], (int)dst_row.ne[3]); + first = false; + } + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); + //ggml_cuda_mul_mat(ctx, next->src[0], &dst_row, &final_dst); + CUDA_CHECK(cudaGetLastError()); + + dim3 block_dims(std::min((unsigned int)next->ne[0], 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + (char *)next->data, final_dst_contiguous.get(), + dev_row_mapping.get(), + next->ne[0], + next->nb[1], next->nb[2]); + CUDA_CHECK(cudaGetLastError()); + + } + else { + + dim3 block_dims(std::min((unsigned int)ne0, 768u)); + dim3 grid_dims(num_src1_rows); + k_copy_dst_from_contiguous<<>>( + dst_original, dst_gate_contiguous.get(), + dev_row_mapping.get(), + ne0, + nb1, nb2); + CUDA_CHECK(cudaGetLastError()); + } + } + } + + return fuse_down; +} + +static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) { // why is this here instead of mul_mat? if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); @@ -2309,6 +2554,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_MUL_MAT_ID: ggml_cuda_mul_mat_id(ctx, dst); break; + case GGML_OP_MOE_FUSED_UP_GATE: + skip_next = ggml_cuda_up_gate_unary(ctx, dst, next); + break; case GGML_OP_SCALE: ggml_cuda_op_scale(ctx, dst); break; @@ -2595,7 +2843,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t #endif } - if (node->op == GGML_OP_MUL_MAT_ID) { + if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); @@ -2666,6 +2914,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (!use_cuda_graph || cuda_graph_update_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr; if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; @@ -2680,11 +2929,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } #endif - bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + bool skip_next = false; + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next); if (!ok) { GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + if (skip_next) ++i; } } @@ -2809,9 +3060,13 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_FUSED_MUL_UNARY: return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: { struct ggml_tensor * a = op->src[0]; - struct ggml_tensor * b = op->src[1]; + struct ggml_tensor * b = op->op == GGML_OP_MOE_FUSED_UP_GATE ? op->src[2] : op->src[1]; + if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) { + return false; + } if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 8ffddd6d..c422abbc 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -297,6 +297,19 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { swiglu_f32_cuda(src0_d, dst_d, ggml_nelements(dst), dst->ne[0], src0->nb[1]/sizeof(float), stream); } +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, const float * src0_d, const float * src1_d, float * dst_d) { + + cudaStream_t stream = ctx.stream(); + + switch (op) { + case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, nelements, stream); break; + default: GGML_ASSERT(false); + } +} + void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -304,19 +317,22 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, src1)); - cudaStream_t stream = ctx.stream(); ggml_unary_op op = (ggml_unary_op)dst->op_params[0]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data); - switch (op) { - case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - default: GGML_ASSERT(false); - } + //cudaStream_t stream = ctx.stream(); + + //const float * src0_d = (const float *)src0->data; + //const float * src1_d = (const float *)src1->data; + //float * dst_d = (float *)dst->data; + + //switch (op) { + // case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; + // default: GGML_ASSERT(false); + //} } void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 0235a319..e55c4262 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -36,5 +36,7 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, const float * x, const float * y, float * z); void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eb39d574..8efe2653 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3845,6 +3845,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "MUL_MAT", "MUL_MAT_ID", "OUT_PROD", + "MOE_FUSED_UP_GATE", "SCALE", "SET", @@ -3904,7 +3905,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3938,6 +3939,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "X*Y", "X[i]*Y", "X*Y", + "X*Y1&X*Y2", "x*v", "y-\\>view(x)", @@ -3997,7 +3999,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6768,6 +6770,51 @@ struct ggml_tensor * ggml_mul_mat_id( return result; } +struct ggml_tensor * ggml_moe_up_gate( + struct ggml_context * ctx, + struct ggml_tensor * as_up, + struct ggml_tensor * as_gate, + struct ggml_tensor * b, + struct ggml_tensor * ids, + enum ggml_unary_op op) { + if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); + struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); + return ggml_fused_mul_unary(ctx, result_gate, result_up, op); + } + GGML_ASSERT(!ggml_is_transposed(as_up)); + GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d + GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row + GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast + + bool is_node = false; + + if (as_up->grad || as_gate->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_MOE_FUSED_UP_GATE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = as_up; + result->src[1] = as_gate; + result->src[2] = b; + result->src[3] = ids; + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + return result; +} + + // ggml_out_prod struct ggml_tensor * ggml_out_prod( @@ -14584,20 +14631,17 @@ IQK_MulMat_Not_Available:; #if GGML_USE_IQK_MULMAT static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_compute_params * params, - struct ggml_tensor * dst1, - struct ggml_tensor * dst2) { + struct ggml_tensor * dst) { - GGML_ASSERT(dst1->src[1] == dst2->src[1]); - GGML_ASSERT(dst1->src[2] == dst2->src[2]); - GGML_ASSERT(dst1->src[0]->type == dst2->src[0]->type); - GGML_ASSERT(dst1->type == GGML_TYPE_F32 && dst2->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == dst->src[1]->type); + GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1])); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const struct ggml_tensor * src1 = dst1->src[1]; - const struct ggml_tensor * ids = dst1->src[2]; - const struct ggml_tensor * src0_1 = dst1->src[0]; - const struct ggml_tensor * src0_2 = dst2->src[0]; - const struct ggml_tensor * src0 = src0_1; - const struct ggml_tensor * dst = dst1; // so GGML_TENSOR_BINARY_OP_LOCALS works + const struct ggml_tensor * src1 = dst->src[2]; + const struct ggml_tensor * ids = dst->src[3]; + const struct ggml_tensor * src0_1 = dst->src[0]; + const struct ggml_tensor * src0_2 = dst->src[1]; + const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works GGML_TENSOR_BINARY_OP_LOCALS @@ -14680,6 +14724,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate( ggml_barrier(params->shared); + + // so GGML_TENSOR_BINARY_OP_LOCALS works + // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { const int64_t cne1 = matrix_row_counts[cur_a]; @@ -14696,28 +14743,34 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows + // + if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], + type, src0_1_cur, src0_2_cur, nb01, + vec_dot_type, (const char *)wdata, row_size, + (float *)dst->data, nb1, nb2, + matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - if (nth%2 == 0) { - const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; - void * dst_d = ith%2 == 0 ? dst1->data : dst2->data; - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - type, src0_d, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst_d, nb1, nb2, - matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); - - } else { - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0_1->type, (const char *)src0_1_cur, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst1->data, nb1, nb2, - matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, - src0_2->type, (const char *)src0_2_cur, nb01, - vec_dot_type, (const char *)wdata, row_size, - (float *)dst2->data, nb1, nb2, - matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); - } +// if (nth%2 == 0) { +// const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; +// void * dst_d = ith%2 == 0 ? dst1->data : dst2->data; +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// type, src0_d, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst_d, nb1, nb2, +// matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); +// +// } else { +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_1->type, (const char *)src0_1_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst1->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, +// src0_2->type, (const char *)src0_2_cur, nb01, +// vec_dot_type, (const char *)wdata, row_size, +// (float *)dst2->data, nb1, nb2, +// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); +// } } #undef MMID_MATRIX_ROW @@ -19152,6 +19205,7 @@ static void ggml_compute_forward_cross_entropy_loss_back( static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) { GGML_ASSERT(params); + GGML_UNUSED(next); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { return false; @@ -19269,16 +19323,12 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT_ID: { -#if GGML_USE_IQK_MULMAT - if (next && next->op == GGML_OP_MUL_MAT_ID && tensor->src[1] == next->src[1] && - tensor->src[0]->type == next->src[0]->type) { - ggml_compute_forward_mul_mat_id_up_gate(params, tensor, next); - skip_next = true; - break; - } -#endif ggml_compute_forward_mul_mat_id(params, tensor); } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + ggml_compute_forward_mul_mat_id_up_gate(params, tensor); + } break; case GGML_OP_OUT_PROD: { ggml_compute_forward_out_prod(params, tensor); @@ -20036,6 +20086,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_MOE_FUSED_UP_GATE: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_OUT_PROD: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -21046,6 +21100,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + case GGML_OP_MOE_FUSED_UP_GATE: case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -21249,6 +21304,20 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; + case GGML_OP_MOE_FUSED_UP_GATE: + { + cur = 0; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src2 = node->src[2]; + const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; + if (src2->type != vec_dot_type) { + cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + } + const int n_as = src0->ne[2]; + cur += GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src2->ne[2] * sizeof(int64_t); // matrix_rows + } break; case GGML_OP_OUT_PROD: { if (ggml_is_quantized(node->src[0]->type)) { diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ed7309cd..0f7cd1e5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -217,6 +217,118 @@ struct MulMat { funcs[n_left-1](n, vx, bx, info, nrc_x); } } + inline void gelu(int n, const float * src, float * dst); + inline void relu(int n, const float * src, float * dst); + inline void silu(int n, const float * src, float * dst); + inline void activate(ggml_unary_op op, int n, const float * src, float * dst) { + if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst); + else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst); + else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst); + else GGML_ABORT("fatal error"); + } + inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) { +#ifdef __aarch64__ + constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) +#else + constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) +#endif + auto op = ggml_unary_op(unary_op); + float tmp[k_x_step*16]; + if (func16 && nrc_y >= 16) { + int n_step = (nrc_y - info.cur_y)/16; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < 16; ++ky) { + activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + } + func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < 16; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += 16; + } + } + info.cur_y += 16 * n_step; + if (info.cur_y == nrc_y) return; + } + int ny = funcs.size(); + while (!funcs[ny-1] && ny > 0) --ny; + int n_left = nrc_y - info.cur_y; + int n_step = n_left/ny; + if (n_step > 0) { + if (n_step*ny != n_left) { + ++n_step; + int ny1 = n_left/n_step; + int ny2 = ny1 + 1; + int my1 = n_step*ny2 - n_left; + int my2 = n_step - my1; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < my1; ++iy) { + funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny1; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny1; + } + for (int iy = 0; iy < my2; ++iy) { + funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny2; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny2; + } + } + info.cur_y += n_left; + } + else { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < ny; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + this_info.cur_y += ny; + } + } + info.cur_y += ny * n_step; + } + } + n_left = nrc_y - info.cur_y; + if (n_left > 0) { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step); + funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x); + for (int ky = 0; ky < n_left; ++ky) { + auto result = this_info.dst_row(ky); + for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j]; + } + } + } + } static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); static inline int num_rows(ggml_type type) { #ifdef HAVE_FANCY_SIMD @@ -414,6 +526,34 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, return true; } +bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, + int typeA, const void * Aup, const void * Agate, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { + + const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; + assert(row_mapping != nullptr); + + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { + return false; + } + size_t row_size_qx = strideA; + size_t row_size_qy = strideB; + auto num_rows = MulMat::num_rows(ggml_type(typeA)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; + first_x *= num_rows; + nrc_x *= num_rows; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), + row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; + mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op); + return true; +} + + namespace { inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { @@ -14660,6 +14800,45 @@ inline float32x4_t v_tanh(float16x8_t x) { auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); } +inline float32x4_t v_silu(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} +inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) { + const float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x)); + arg = vmulq_f32(arg, vmulq_f32(x, c2)); + float32x4_t exp_arg = v_expf(arg); + float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one))); + uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f)); + return vbslq_f32(mask, x, gelu); +} + +void MulMat::gelu(int n, const float * x, float * y) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + int i = 0; + auto c1 = vdupq_n_f32(GELU_COEF_A); + auto c2 = vdupq_n_f32(2.f*SQRT_2_OVER_PI); + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, v_gelu(vld1q_f32(x + i), c1, c2)); + } + for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); +} + +void MulMat::silu(int n, const float * x, float * y) { + int i = 0; + for (; i + 3 < n; i += 4) vst1q_f32(y + i, v_silu(vld1q_f32(x + i))); + for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i])); +} + +void MulMat::relu(int n, const float * x, float * y) { + for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0; +} #endif #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -14702,6 +14881,24 @@ inline __m512 v_tanh(__m512 x) { const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one)); return _mm512_mask_blend_ps(mask, res, one); } +inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) { + const __m512 one = _mm512_set1_ps(1.0f); + __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one); + //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1)); + arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x)); + const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ); + const __m512 exp_arg = v_expf(arg); + const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one)); + return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one)); +} +inline static __m512 v_silu(__m512 x) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} #endif #if defined(__AVX2__) && defined(__FMA__) @@ -14755,6 +14952,61 @@ inline __m256 v_tanh(__m256 x) { const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res)); } +inline static __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) { + const __m256 one = _mm256_set1_ps(1.0f); + const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ); + __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1)); + arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2)); + __m256 exp_arg = v_expf(arg); + __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one))); + return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu)); +} +inline static __m256 v_silu(__m256 x) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} + +void MulMat::gelu(int n, const float * x, float * y) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + //GGML_ASSERT(n%8 == 0); + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + { + __m512 c1 = _mm512_set1_ps(GELU_COEF_A); + __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2)); + } +#endif +#if defined __AVX2__ && defined __FMA__ + if (i + 7 < n) { + __m256 c1 = _mm256_set1_ps(GELU_COEF_A); + __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI); + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2)); + + } +#endif + for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i]))); +} + +void MulMat::silu(int n, const float * x, float * y) { + int i = 0; +#if defined __AVX512F__ && defined __AVX512DQ__ + for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_silu(_mm512_loadu_ps(x + i))); +#endif +#if defined __AVX2__ && defined __FMA__ + for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_silu(_mm256_loadu_ps(x + i))); +#endif + for (; i < n; ++i) y[i] = x[i]/(1.0f + expf(-x[i])); +} + +void MulMat::relu(int n, const float * x, float * y) { + for (int j = 0; j < n; ++j) y[j] = x[j] > 0 ? x[j] : 0; +} #endif } // namespace @@ -17107,6 +17359,14 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const return false; } +bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/, + int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/, + int /*typeB*/, const void * /*B*/, long /*strideB*/, + float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) { + return false; +} + + bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k [[maybe_unused]] int int_type_v, // type of v [[maybe_unused]] int D, // head size diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index d5f340b2..767f89cf 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -28,6 +28,11 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); +bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, + int typeA, const void * Aup, const void * Agate, long strideA, + int typeB, const void * B, long strideB, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); + bool iqk_flash_attn_noalibi(int type_k, // type of k int type_v, // type of v int Dk, // K head size diff --git a/include/llama.h b/include/llama.h index b5ad65e7..23e32642 100644 --- a/include/llama.h +++ b/include/llama.h @@ -377,6 +377,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] + bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 28e887ee..eed7aa61 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2516,6 +2516,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool mla_attn; + bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -8628,30 +8629,20 @@ llm_expert_gating_func_type gating_op, } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(gate, "ffn_moe_gate", il); + ggml_tensor * par; + if (lctx.cparams.fused_moe_up_gate) { + par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } else { + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); - // This is equivalent to the commented out code below - ggml_tensor * par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); - //switch (type_op) { - // case LLM_FFN_SILU: - // { - // gate = ggml_silu(ctx, gate); - // cb(gate, "ffn_moe_silu", il); - // } break; - // case LLM_FFN_GELU: - // { - // gate = ggml_gelu(ctx, gate); - // cb(gate, "ffn_moe_gelu", il); - // } break; - // default: - // GGML_ABORT("fatal error"); - //} - //ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + // This is equivalent to the commented out code below + par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } cb(par, "ffn_moe_gate_par", il); ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] @@ -8907,6 +8898,7 @@ struct llm_build_context { const bool flash_attn; const bool mla_attn; + const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -8958,6 +8950,7 @@ struct llm_build_context { n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), mla_attn (cparams.mla_attn), + fused_moe_up_gate(cparams.fused_moe_up_gate), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -17605,6 +17598,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.mla_attn =*/ false, + /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17804,6 +17798,7 @@ struct llama_context * llama_new_context_with_model( cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.fused_moe_up_gate= params.fused_moe_up_gate; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17871,6 +17866,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); From 547eee81d99a2676975a9768166b7d164473b8fa Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 24 Feb 2025 09:29:58 +0200 Subject: [PATCH 022/132] Fix #230 (#231) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 0f7cd1e5..0955f15d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4062,7 +4062,7 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn template static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m5 = _mm256_set1_epi8(0x10); @@ -4232,7 +4232,7 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn template static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m6 = _mm256_set1_epi8(0x30); @@ -6493,7 +6493,6 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI #endif return; } - GGML_ASSERT(nrc_x%8 == 0); __m256i qx[2]; __m256i acc[2*nrc_y] = {}; float dy[nrc_y]; @@ -6566,7 +6565,7 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI template static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(nrc_x%4 == 0); GGML_ASSERT(n%32 == 0); __m256i qx[4]; #ifndef HAVE_FANCY_SIMD From 94b659a2f106e017e5eeb6f492dc9f290e136833 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 25 Feb 2025 17:55:58 +0200 Subject: [PATCH 023/132] Give the user the option to override where model weights are stored (#232) * Give the user the option to override where model weights are stored * Fix ggml_nbytes() problem and cleanup For a tensor with zero elements ggml_nbytes() was returning uint64_t::max, and this was causing graph allocation failure. * Add timing info to CUDA graph evaluation * Add more timing info --------- Co-authored-by: Iwan Kawrakow --- common/common.cpp | 51 + common/common.h | 1 + examples/llama-bench/llama-bench.cpp | 53 ++ ggml/src/ggml-alloc.c | 2 + ggml/src/ggml-backend.c | 32 + ggml/src/ggml-cuda.cu | 11 + ggml/src/ggml.c | 10 + include/llama.h | 7 + src/llama.cpp | 1302 ++++++++++++++------------ 9 files changed, 848 insertions(+), 621 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f975aee3..464b4710 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -265,6 +265,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0; } + if (!params.tensor_buft_overrides.empty()) { + params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } return true; } @@ -287,6 +290,40 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } +namespace { +bool parse_buft_overrides(const std::string& value, std::vector& overrides) { + /* static */ std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) { + //auto * dev = ggml_backend_reg_get_name(i); + auto * buft = ggml_backend_reg_get_default_buffer_type(i); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + } + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + fprintf(stderr, "Invalid buft override argument %s\n", value.c_str()); + return false; + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + if (buft_list.find(buffer_type) == buft_list.end()) { + fprintf(stderr, "Available buffer types:\n"); + for (const auto & it : buft_list) { + fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second)); + } + return false; + } + overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + return true; +} +} + #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { @@ -1120,6 +1157,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--override-tensor" || arg == "-ot") { + CHECK_ARG + if (!parse_buft_overrides(std::string{argv[i]}, params.tensor_buft_overrides)) { + fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]); + invalid_param = true; + } + return true; + } if (arg == "--host") { CHECK_ARG params.hostname = argv[i]; @@ -2238,6 +2283,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key"); mparams.kv_overrides = params.kv_overrides.data(); } + if (params.tensor_buft_overrides.empty()) { + mparams.tensor_buft_overrides = NULL; + } else { + GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern"); + mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); + } return mparams; } diff --git a/common/common.h b/common/common.h index f86a58cb..152fd1cf 100644 --- a/common/common.h +++ b/common/common.h @@ -135,6 +135,7 @@ struct gpt_params { std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; + std::vector tensor_buft_overrides; bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) std::vector lora_adapters; // lora adapter path with user defined scale diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index b0790e20..438d2a7c 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -236,6 +236,7 @@ struct cmd_params { std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; + std::vector buft_overrides; ggml_numa_strategy numa; int reps; bool verbose; @@ -267,6 +268,7 @@ static const cmd_params cmd_params_defaults = { /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, + /* buft_overrides */ {}, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* verbose */ false, @@ -309,6 +311,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); + printf(" -ot, --override-tensor pattern (default: none)\n"); printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); @@ -349,6 +352,39 @@ static ggml_type ggml_type_from_name(const std::string & s) { return GGML_TYPE_COUNT; } +namespace { +bool parse_buft_overrides(const std::string& value, std::vector& overrides) { + /* static */ std::map buft_list; + if (buft_list.empty()) { + // enumerate all the devices and add their buffer types to the list + for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) { + //auto * dev = ggml_backend_reg_get_name(i); + auto * buft = ggml_backend_reg_get_default_buffer_type(i); + if (buft) { + buft_list[ggml_backend_buft_name(buft)] = buft; + } + } + } + for (const auto & override : string_split(value, ',')) { + std::string::size_type pos = override.find('='); + if (pos == std::string::npos) { + fprintf(stderr, "Invalid buft override argument %s\n", value.c_str()); + return false; + } + std::string tensor_name = override.substr(0, pos); + std::string buffer_type = override.substr(pos + 1); + if (buft_list.find(buffer_type) == buft_list.end()) { + fprintf(stderr, "Available buffer types:\n"); + for (const auto & it : buft_list) { + fprintf(stderr, " %s\n", ggml_backend_buft_name(it.second)); + } + return false; + } + overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)}); + } + return true; +} +} static cmd_params parse_cmd_params(int argc, char ** argv) { cmd_params params; @@ -616,6 +652,16 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.fmoe = std::stoi(argv[i]); + } else if (arg == "-ot" || arg == "--override-tensor") { + if (++i >= argc) { + invalid_param = true; + break; + } + if (!parse_buft_overrides(std::string{argv[i]}, params.buft_overrides)) { + fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]); + invalid_param = true; + break; + } } else { invalid_param = true; break; @@ -648,6 +694,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } + if (!params.buft_overrides.empty()) params.buft_overrides.emplace_back(llama_model_tensor_buft_override{nullptr, nullptr}); return params; } @@ -685,6 +732,7 @@ struct cmd_params_instance { bool embeddings; bool repack = false; bool fmoe = false; + const llama_model_tensor_buft_override* buft_overrides; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -698,6 +746,7 @@ struct cmd_params_instance { mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; mparams.repack_tensors = repack; + mparams.tensor_buft_overrides = buft_overrides; return mparams; } @@ -777,6 +826,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -807,6 +857,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -837,6 +888,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } @@ -867,6 +919,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); } diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index e485326a..d811dee6 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -174,6 +174,8 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz // this should never happen fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", __func__, size, max_avail); + fprintf(stderr, "%s: tensor was %s with %zu elements and %zu bytes\n", __func__, tensor->name, + ggml_nelements(tensor), ggml_nbytes(tensor)); GGML_ABORT("not enough space in the buffer"); } } diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index e1651cc6..0458bd0c 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -9,6 +9,7 @@ #include #include +#define IK_PRINT_TIMING 0 #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -229,7 +230,17 @@ GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * return; } + +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif buf->iface.set_tensor(buf, tensor, data, offset, size); +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + //printf("%s(%s) %zu %d us\n", __func__, tensor->name, size, (int)(tim2-tim1)); + printf("%s(%s): %d us\n", __func__, tensor->name, (int)(tim2-tim1)); +#endif + } GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -243,7 +254,15 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * return; } +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif buf->iface.get_tensor(buf, tensor, data, offset, size); +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + //printf("%s(%s) %zu %d us\n", __func__, tensor->name, size, (int)(tim2-tim1)); + printf("%s(%s): %d us\n", __func__, tensor->name, (int)(tim2-tim1)); +#endif } void ggml_backend_synchronize(ggml_backend_t backend) { @@ -1751,7 +1770,11 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { struct ggml_backend_sched_split * splits = sched->splits; + for (int i = 0; i < sched->n_splits; i++) { +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif struct ggml_backend_sched_split * split = &splits[i]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; @@ -1792,6 +1815,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } if (!sched->callback_eval) { +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(.1.): %d us\n", __func__, (int)(tim2-tim1)); +#endif enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); if (ec != GGML_STATUS_SUCCESS) { return ec; @@ -1814,6 +1841,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(.2.): %d us\n", __func__, (int)(tim2-tim1)); +#endif + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); if (ec != GGML_STATUS_SUCCESS) { return ec; diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 26d06d56..c305cd89 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -50,6 +50,8 @@ #include #include +#define IK_PRINT_TIMING 0 + static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { @@ -2446,6 +2448,10 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); } +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif + switch (dst->op) { case GGML_OP_REPEAT: ggml_cuda_op_repeat(ctx, dst); @@ -2618,6 +2624,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg CUDA_CHECK(err); } +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(%s): %d us\n", ggml_op_name(dst->op), dst->name, (int)(tim2 - tim1)); +#endif + return true; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8efe2653..80dd25ff 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4267,6 +4267,9 @@ GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { } GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) return 0; + } size_t nbytes; size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { @@ -21480,6 +21483,9 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl #ifdef GGML_USE_OPENMP if (n_threads > 1) { +//#if IK_PRINT_TIMING +// int64_t tim1 = ggml_time_us(); +//#endif #pragma omp parallel num_threads(n_threads) { #pragma omp single @@ -21496,6 +21502,10 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl }; ggml_graph_compute_thread(&worker); } +//#if IK_PRINT_TIMING +// int64_t tim2 = ggml_time_us(); +// printf("%s(...): %d us\n", __func__, (int)(tim2-tim1)); +//#endif } else { struct ggml_compute_state worker = { .thrd = 0, diff --git a/include/llama.h b/include/llama.h index 23e32642..beb6ecba 100644 --- a/include/llama.h +++ b/include/llama.h @@ -305,6 +305,11 @@ extern "C" { }; }; + struct llama_model_tensor_buft_override { + const char * pattern; + ggml_backend_buffer_type_t buft; + }; + struct llama_model_params { int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs @@ -332,6 +337,8 @@ extern "C" { // override key-value pairs of the model meta data const struct llama_model_kv_override * kv_overrides; + const struct llama_model_tensor_buft_override * tensor_buft_overrides; + // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible diff --git a/src/llama.cpp b/src/llama.cpp index eed7aa61..0276b69c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -12,6 +12,8 @@ // TODO: fix this include #include "iqk/iqk_quantize.h" +#define IK_PRINT_TIMING 0 + #ifdef GGML_USE_RPC # include "ggml-rpc.h" #endif @@ -99,6 +101,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -3855,6 +3858,7 @@ struct llama_model_loader { std::vector weights; std::unordered_map kv_overrides; + const llama_model_tensor_buft_override * tensor_buft_overrides; struct gguf_context * meta = NULL; std::vector contexts; @@ -3862,7 +3866,9 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, + const llama_model_kv_override * param_overrides_p, + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -3874,6 +3880,8 @@ struct llama_model_loader { } } + tensor_buft_overrides = param_tensor_buft_overrides_p; + struct ggml_context * ctx = NULL; struct gguf_init_params params = { /*.no_alloc = */ true, @@ -6414,6 +6422,26 @@ static bool llm_load_tensors( model.ctxs.push_back(ctx); } + auto ctx_for_buft = [&model, &ctx_map, ctx_size](ggml_backend_buffer_type_t buft) -> ggml_context * { + if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second; + + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map[buft] = ctx; + model.ctxs.emplace_back(ctx); + + return ctx; + }; + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, model.ctxs.size()*ctx_size/1024.0/1024.0); // create tensors for the weights @@ -6447,6 +6475,20 @@ static bool llm_load_tensors( model.layers.resize(n_layer); + auto create_tensor = [&ml, &ctx_map, &ctx_for_buft] (ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) { + if (ml.tensor_buft_overrides) { + for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(name, pattern)) { + LLAMA_LOG_INFO("Tensor %s buffer type overriden to %s\n", name.c_str(), ggml_backend_buft_name(overrides->buft)); + ctx = ctx_for_buft(overrides->buft); + break; + } + } + } + return ml.create_tensor(ctx, name, ne, flags); + }; + const auto tn = LLM_TN(model.arch); switch (model.arch) { case LLM_ARCH_LLAMA: @@ -6455,16 +6497,16 @@ static bool llm_load_tensors( case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6474,39 +6516,39 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); if (n_expert == 0) { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional MLP bias - layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } else { // merge split expert into a single tensor for compatibility with older models // requires disabling mmap @@ -6540,16 +6582,16 @@ static bool llm_load_tensors( throw std::runtime_error("Grok model cannot have zero experts"); } - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6559,23 +6601,23 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } else { // merge split expert into a single tensor for compatibility with older models // requires disabling mmap @@ -6601,7 +6643,7 @@ static bool llm_load_tensors( } } - layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); } } break; case LLM_ARCH_DBRX: @@ -6610,12 +6652,12 @@ static bool llm_load_tensors( throw std::runtime_error("DBRX model cannot have zero experts"); } - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6624,25 +6666,25 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } } break; case LLM_ARCH_BAICHUAN: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6651,32 +6693,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_FALCON: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -6686,32 +6728,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2 = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_STARCODER: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { // needs to be on GPU - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6722,37 +6764,37 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.type_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); if (model.arch == LLM_ARCH_BERT) { - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); } - model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_norm = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -6761,45 +6803,45 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; if (model.arch == LLM_ARCH_BERT) { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); } else { - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); } - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + layer.attn_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); if (model.arch == LLM_ARCH_BERT) { - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); } else { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); } - layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); - layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); } } break; case LLM_ARCH_JINA_BERT_V2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings - model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // word_embeddings + model.type_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings - model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm - model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias + model.tok_norm = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm + model.tok_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); //LayerNorm bias for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -6807,51 +6849,51 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; // JinaBertLayer - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens - layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens + layer.bo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens - layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm - layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm + layer.attn_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); - layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2 = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); - layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); + layer.layer_out_norm = create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + layer.layer_out_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}); } } break; case LLM_ARCH_BLOOM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.tok_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); - model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_norm = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); + model.tok_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6860,38 +6902,38 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_MPT: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -6901,43 +6943,43 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); // AWQ ScaleActivation layer - layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_act = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_STABLELM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6946,40 +6988,40 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_QWEN: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -6988,30 +7030,30 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); } } break; case LLM_ARCH_QWEN2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7021,33 +7063,33 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_QWEN2MOE: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7056,21 +7098,21 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); GGML_ASSERT(n_expert > 0); GGML_ASSERT(n_expert_used > 0); @@ -7078,29 +7120,29 @@ static bool llm_load_tensors( // MoE branch const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}); + layer.ffn_gate_inp_shexp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}); + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}); } } break; case LLM_ARCH_PHI2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); - model.output_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7109,43 +7151,43 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); } - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_PHI3: { const int64_t n_embd_head = n_embd / n_head; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }); } for (int i = 0; i < n_layer; ++i) { @@ -7154,28 +7196,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); - layer.rope_long = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); - layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_long = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); } } break; case LLM_ARCH_PLAMO: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7184,28 +7226,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_GPT2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7214,34 +7256,34 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_CODESHELL: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7250,32 +7292,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_ORION: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7283,30 +7325,30 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_INTERNLM2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7315,26 +7357,26 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - // layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + // layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_GEMMA: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7342,26 +7384,26 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); } } break; case LLM_ARCH_GEMMA2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading for (int i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7369,34 +7411,34 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); - layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}); } } break; case LLM_ARCH_STARCODER2: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7407,29 +7449,29 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional bias tensors - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}); } } break; case LLM_ARCH_MAMBA: @@ -7442,16 +7484,16 @@ static bool llm_load_tensors( // only an expansion factor of 2 is supported for now GGML_ASSERT(2 * n_embd == d_inner); - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7462,32 +7504,32 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; // norm - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + layer.ssm_in = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); - layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); - layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + layer.ssm_conv1d = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); - layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + layer.ssm_x = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); - layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); - layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + layer.ssm_dt = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); // no "weight" suffix for these - layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); - layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + layer.ssm_a = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); // out_proj - layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + layer.ssm_out = create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } } break; case LLM_ARCH_XVERSE: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7496,28 +7538,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_COMMAND_R: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); // init output from the input tok embed - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -7526,33 +7568,33 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); if (n_layer >= 64){ - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}); } - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7561,25 +7603,25 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_OPENELM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); // init output from the input tok embed - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -7592,28 +7634,28 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_GPTNEOX: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7622,37 +7664,37 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_ARCTIC: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7662,24 +7704,24 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_norm_exps = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } } break; case LLM_ARCH_DEEPSEEK2: @@ -7695,12 +7737,12 @@ static bool llm_load_tensors( const int64_t n_ff_exp = hparams.n_ff_exp; const int64_t n_expert_shared = hparams.n_expert_shared; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7709,59 +7751,59 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); if (!is_lite) { - layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + layer.attn_q_a_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); } - layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); + layer.attn_kv_a_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); if (!is_lite) { - layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); - layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}); + layer.wq_a = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); + layer.wq_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}); } else { - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); } - layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); - layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}); - layer.wk_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); - layer.wv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); + layer.wkv_a_mqa = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); + layer.wkv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}); + layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); + layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } else { - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 1); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 1); GGML_ASSERT(n_expert > 0); GGML_ASSERT(n_expert_used > 0); // MoE branch - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); // Shared expert branch - layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); - layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}); - layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}); } } } break; case LLM_ARCH_BITNET: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading } const uint32_t n_ff = hparams.n_ff(); @@ -7772,44 +7814,44 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wq_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wk_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wv_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo_scale = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_T5: { const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_enc = create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7819,55 +7861,55 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b_enc = create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b = create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.attn_norm_cross = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_cross = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}); // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_rel_b_cross = create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_cross = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_T5ENCODER: { const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm_enc = create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -7877,29 +7919,29 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_rel_b_enc = create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); + layer.wq_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wk_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}); - layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_norm_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate_enc = create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up_enc = create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; case LLM_ARCH_JAIS: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // Output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7908,36 +7950,36 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; case LLM_ARCH_CHATGLM: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); } for (int i = 0; i < n_layer; ++i) { @@ -7946,18 +7988,18 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)}); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); } } break; default: @@ -8159,7 +8201,8 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.repack_tensors, params.kv_overrides); + llama_model_loader ml(fname, params.use_mmap, params.check_tensors, + params.repack_tensors, params.kv_overrides, params.tensor_buft_overrides); model.hparams.vocab_only = params.vocab_only; @@ -14477,6 +14520,10 @@ static struct ggml_cgraph * llama_build_graph( bool worst_case) { const auto & model = lctx.model; +#if IK_PRINT_TIMING + auto tim1 = ggml_time_us(); +#endif + // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) { if (il >= 0) { @@ -14690,6 +14737,11 @@ static struct ggml_cgraph * llama_build_graph( llm.free(); +#if IK_PRINT_TIMING + auto tim2 = ggml_time_us(); + printf("%s(...): %d us\n", __func__, int(tim2-tim1)); +#endif + return result; } @@ -14746,6 +14798,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // set input data // +#if IK_PRINT_TIMING + auto tim1 = ggml_time_us(); +#endif const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -15105,6 +15160,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } } +#if IK_PRINT_TIMING + auto tim2 = ggml_time_us(); + printf("%s(...): %d us\n", __func__, int(tim2-tim1)); +#endif } // Make sure enough space is available for outputs. @@ -16839,7 +16898,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model; @@ -17554,6 +17613,7 @@ struct llama_model_params llama_model_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.tensor_buft_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, From 51029edfdf286df76f9268fc87b9514291b2fe42 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 27 Feb 2025 08:42:18 +0200 Subject: [PATCH 024/132] Faster MLA on CUDA (#234) * Slight MLA TG performance improvement on CUDA The low MLA performance on CUDA is dues to the wk_b * q_nope operation. It turns into n_head matrix multiplications with n_head separate quantization and GEMV steps. The associated overhead is just too much for TG where each GEMV is very fast (512 x 128 = 131 KFLOP for DeepSeek-Lite, 4X that for DeepSeekV3/R1). The way it was done there was also a copy of each q_nope row before quantization, which I have now eliminated. This results in a ~2.5% speedup. What needs to happen instead is to launch a single computation that quantizes all heads, and then have a kernel that does the GEMV for all heads instead of n_head sequential GEMVs. * Slightly better * CUDA: Quantize non-contiguous tensors * Much better MLA It is a total hack, but it works. * Cleanup Remove duplicated gemv's. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 36 ++- ggml/src/ggml-cuda/iqk_mmvq.cu | 93 +++++--- ggml/src/ggml-cuda/iqk_mmvq.cuh | 30 ++- ggml/src/ggml-cuda/mmvq.cu | 403 ++++++++++++++++++++------------ ggml/src/ggml-cuda/mmvq.cuh | 8 +- ggml/src/ggml-cuda/quantize.cu | 52 +++++ ggml/src/ggml-cuda/quantize.cuh | 3 + src/llama.cpp | 3 +- 8 files changed, 420 insertions(+), 208 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index c305cd89..bc960678 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1518,6 +1518,8 @@ static void ggml_cuda_op_mul_mat( } } + bool quantization_done = false; + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) { continue; @@ -1561,9 +1563,15 @@ static void ggml_cuda_op_mul_mat( } dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); - if (src1_on_device && src1_is_contiguous) { - quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); + if (src1_on_device && (src1_is_contiguous || (src1->ne[1] == 1 && src1->ne[3] == 1 && src1->nb[0] == sizeof(float)))) { + if (src1_is_contiguous) { + quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); + } else { + //printf("Calling quantize_tensor_q8_1_cuda for %s\n", src0->name); + quantize_tensor_q8_1_cuda(src1, dev[id].src1_ddq, src0->type, stream); + } CUDA_CHECK(cudaGetLastError()); + quantization_done = true; } } @@ -1583,6 +1591,20 @@ static void ggml_cuda_op_mul_mat( } const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + if (!(split && used_devices > 1) && quantization_done && ne11 == 1 && ne12 > 1 && ne13 == 1) { + //printf("invoking fast path for %s x %s\n", src0->name, src1->name); + int id = ctx.device; + char * src0_dd_i = dev[id].src0_dd; + float * src1_ddf_i = dev[id].src1_ddf; + char * src1_ddq_i = dev[id].src1_ddq; + float * dst_dd_i = dev[id].dst_dd; + cudaStream_t stream = ctx.stream(id, 0); + ggml_cuda_op_mul_mat_vec_q_3D(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + dev[id].row_low, dev[id].row_high, ne11, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + return; + } + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0; const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; @@ -1649,13 +1671,17 @@ static void ggml_cuda_op_mul_mat( } } } else if (src1_on_device && !src1_is_contiguous) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d( - src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + if (!quantization_done) { + //printf("Copying %s\n", src1->name); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } } else { GGML_ABORT("fatal error"); } - if (quantize_src1 && !src1_is_contiguous) { + if (quantize_src1 && !src1_is_contiguous && !quantization_done) { + //printf("Quantizing %s\n", src1->name); quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 36dbb52a..2eafe463 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -15,11 +15,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ namespace { template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -// tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__global__ void iqk_mul_mat_vec_q( +__device__ void iqk_mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) { @@ -94,10 +90,27 @@ __global__ void iqk_mul_mat_vec_q( } } +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__global__ void iqk_mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size, + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) { + int i2 = blockIdx.y; + const char * cx = (const char *)vx + i2*nb02; + const char * cy = (const char *)vy + i2*nb12; + char * cdst = (char *)dst + i2*nb2; + iqk_mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); +} + template void iqk_mul_mat_vec_q_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -132,35 +145,35 @@ void iqk_mul_mat_vec_q_cuda( } } const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; - const dim3 block_nums(nblocks, 1, 1); + const dim3 block_nums(nblocks, ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); const int64_t row_size = ggml_row_size(type, ncols_x); switch (ncols_y) { case 1: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 2: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 3: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 4: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 5: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 6: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 7: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; case 8: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); break; default: GGML_ASSERT(false); @@ -730,68 +743,78 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( void mul_mat_vec_iq2_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq3_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq4_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq5_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq6_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq1_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } void mul_mat_vec_iq2_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 1693a73a..a128bc53 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -2,40 +2,50 @@ void mul_mat_vec_iq2_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq3_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq4_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq5_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq6_k_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq4_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq4_kss_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq2_ks_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq1_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); void mul_mat_vec_iq2_bn_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream); + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index cdf13533..2a520d68 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,6 +1,6 @@ #include "mmvq.cuh" -#include "vecdotq.cuh" #include "iqk_mmvq.cuh" +#include "vecdotq.cuh" typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); @@ -51,11 +51,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { } template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -// tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void mul_mat_vec_q( +static __device__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -132,10 +128,27 @@ static __global__ void mul_mat_vec_q( } } +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) { + int i2 = blockIdx.y; + const char * cx = (const char *)vx + i2*nb02; + const char * cy = (const char *)vy + i2*nb12; + char * cdst = (char *)dst + i2*nb2; + mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); +} + template static void mul_mat_vec_q_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -145,6 +158,11 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; + //if (ne2 > 1) { + // printf("%s: ncols_x = %d, nrows_x = %d, nrows_y = %d, ncols_y = %d nrows_dst = %d, ne2 = %d nb02 = %zu, nb12 = %zu, nb2 = %zu\n", __func__, + // ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2); + //} + if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: @@ -170,33 +188,33 @@ static void mul_mat_vec_q_cuda( } } const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; - const dim3 block_nums(nblocks, 1, 1); + const dim3 block_nums(nblocks, ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); switch (ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); break; default: GGML_ABORT("fatal error"); @@ -206,142 +224,311 @@ static void mul_mat_vec_q_cuda( static void mul_mat_vec_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q6_0_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q3_K_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq2_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq2_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq2_s_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq3_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq1_s_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq1_m_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq4_nl_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq4_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); +} + +namespace { +void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, + const int64_t ne00, const int64_t ne10, const int64_t ne0, const int64_t ne2, + const int64_t nb02, const int64_t nb12, const int64_t nb2, + const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, + const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t row_diff = row_high - row_low; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); + + switch (type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q6_0: + mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ1_BN: + mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ2_BN: + mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ2_K: + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ3_K: + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ4_K: + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ4_KS: + mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ4_KSS: + mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ2_KS: + mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ5_K: + mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ6_K: + mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } + +} +} + +void ggml_cuda_op_mul_mat_vec_q_3D( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); + GGML_ASSERT(src0->ne[2] == src1->ne[2] && src0->ne[2] == dst->ne[2]); + + const int64_t ne0 = dst->ne[0]; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); + + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne10, ne0, dst->ne[2], + src0->nb[2], src1_row_size, dst->nb[2], + src0_dd_i, src1_ddq_i, dst_dd_i, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); + + GGML_UNUSED(src1_ddf_i); } void ggml_cuda_op_mul_mat_vec_q( @@ -364,105 +551,11 @@ void ggml_cuda_op_mul_mat_vec_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q6_0: - mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_BN: - mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_BN: - mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_K: - mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_K: - mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_K: - mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_KS: - mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_KSS: - mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_KS: - mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ5_K: - mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ6_K: - mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne10, ne0, 1, 0, 0, 0, + src0_dd_i, src1_ddq_i, dst_dd_i, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); - GGML_UNUSED(src1); - GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); } diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index d9e42fdd..f86aefe2 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,8 +2,12 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. -void ggml_cuda_op_mul_mat_vec_q( - ggml_backend_cuda_context & ctx, +void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); + +void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 65c7e5f1..90dfac92 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -37,6 +37,42 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded, const uint64_t stride) { + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (ix0 >= kx0_padded) { + return; + } + + const int64_t ix1 = blockIdx.y; + + const int64_t i_padded = ix1*kx0_padded + ix0; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int64_t ib = i_padded / QK8_1; // block index + const int64_t iqs = i_padded % QK8_1; // quant index + + const float xi = ix0 < kx ? x[ix1*stride + ix0] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} + template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { @@ -164,3 +200,19 @@ void quantize_mmq_q8_1_cuda( break; } } + +void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream) { + GGML_ASSERT(src->ne[1] == 1 && src->ne[3] == 1); + GGML_ASSERT(src->type == GGML_TYPE_F32); + const int64_t src_padded_col_size = GGML_PAD(src->ne[0], MATRIX_ROW_PADDING); + GGML_ASSERT(src_padded_col_size % QK8_1 == 0); + if (src->ne[2] == 1 || ggml_is_contiguous(src)) { + quantize_row_q8_1_cuda((const float *)src->data, vy, src->ne[0], 1, 1, src_padded_col_size, type, stream); + return; + } + const int64_t block_num_x = (src_padded_col_size + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, src->ne[2]*src->ne[3], 1); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + const uint64_t stride = src->nb[2]/sizeof(float); + quantize_q8_1<<>>((const float *)src->data, vy, src->ne[0], src_padded_col_size, stride); +} diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 03bf322b..69622e18 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -22,3 +22,6 @@ void quantize_row_q8_1_cuda( void quantize_mmq_q8_1_cuda( const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream); + +// For now only applicable for tensors with ne[1] = 1, ne[3] = 1, and useful if ne[2] > 1 +void quantize_tensor_q8_1_cuda(const struct ggml_tensor * src, void * vy, const enum ggml_type type, cudaStream_t stream); diff --git a/src/llama.cpp b/src/llama.cpp index 0276b69c..ebc7a772 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3212,7 +3212,7 @@ static bool llama_kv_cache_init( ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); #else - ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_v, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); + ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_v, kv_lora_rank + n_embd_head_qk_rope, kv_size); #endif ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); @@ -13579,6 +13579,7 @@ struct llm_build_context { cb(wk_b, "wk_b", il); q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + //if (q_nope->ne[1] <= 32) q_nope = ggml_cont(ctx0, q_nope); cb(q_nope, "q_nope_perm", il); struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); From b762db7c9264199c2d0f66e7d63e3b4884f3fc0c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 27 Feb 2025 16:40:49 +0200 Subject: [PATCH 025/132] Option to use MLA without a transposed cache (#235) The `-mla` command line option turns into an int from a bool. mla = 0: use standard attention mla = 1: use MLA with transposed cache mla > 1: use MLA without transposed cache Co-authored-by: Iwan Kawrakow --- common/common.cpp | 7 +- common/common.h | 2 +- examples/llama-bench/llama-bench.cpp | 16 ++--- ggml/src/ggml-cuda/mmvq.cu | 26 ++----- include/llama.h | 2 +- src/llama.cpp | 102 ++++++++++++--------------- 6 files changed, 64 insertions(+), 91 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 464b4710..6359426f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -851,7 +851,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-mla" || arg == "--mla-use") { - params.mla_attn = true; + CHECK_ARG + params.mla_attn = std::stoi(argv[i]); return true; } if (arg == "-fmoe" || arg == "--fused-moe") { @@ -1514,7 +1515,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); - options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -3357,7 +3358,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); + fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 152fd1cf..ef5175f3 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention - bool mla_attn = false; // MLA + int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 438d2a7c..5756843a 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -232,7 +232,7 @@ struct cmd_params { std::vector main_gpu; std::vector no_kv_offload; std::vector flash_attn; - std::vector mla_attn; + std::vector mla_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -264,7 +264,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* flash_attn */ {false}, - /* mla_attn */ {false}, + /* mla_attn */ {0}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -300,7 +300,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); - printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); + printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -576,7 +576,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { invalid_param = true; break; } - auto p = string_split(argv[i], split_delim); + auto p = string_split(argv[i], split_delim); params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { @@ -726,7 +726,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool flash_attn; - bool mla_attn; + int mla_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -955,7 +955,7 @@ struct test { int main_gpu; bool no_kv_offload; bool flash_attn; - bool mla_attn; + int mla_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -1097,13 +1097,13 @@ struct test { field == "n_threads" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || - field == "n_prompt" || field == "n_gen" || + field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "avg_ns" || field == "stddev_ns") { return INT; } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "fused_moe") { return BOOL; } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 2a520d68..e17e77a3 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -158,11 +158,6 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - //if (ne2 > 1) { - // printf("%s: ncols_x = %d, nrows_x = %d, nrows_y = %d, ncols_y = %d nrows_dst = %d, ne2 = %d nb02 = %zu, nb12 = %zu, nb2 = %zu\n", __func__, - // ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2); - //} - if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: @@ -382,9 +377,8 @@ static void mul_mat_vec_iq3_s_q8_1_cuda( mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); } -namespace { -void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, - const int64_t ne00, const int64_t ne10, const int64_t ne0, const int64_t ne2, +static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, + const int64_t ne00, const int64_t ne0, const int64_t ne2, const int64_t nb02, const int64_t nb12, const int64_t nb2, const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, @@ -496,7 +490,6 @@ void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type } } -} void ggml_cuda_op_mul_mat_vec_q_3D( ggml_backend_cuda_context & ctx, @@ -505,8 +498,6 @@ void ggml_cuda_op_mul_mat_vec_q_3D( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); @@ -516,13 +507,10 @@ void ggml_cuda_op_mul_mat_vec_q_3D( int id = ggml_cuda_get_device(); - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, - ne00, ne10, ne0, dst->ne[2], + ne00, ne0, dst->ne[2], src0->nb[2], src1_row_size, dst->nb[2], src0_dd_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols, @@ -538,8 +526,6 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); @@ -547,12 +533,8 @@ void ggml_cuda_op_mul_mat_vec_q( int id = ggml_cuda_get_device(); - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, - ne00, ne10, ne0, 1, 0, 0, 0, + ne00, ne0, 1, 0, 0, 0, src0_dd_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols, src1_padded_row_size, stream); diff --git a/include/llama.h b/include/llama.h index beb6ecba..2b33701c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -383,7 +383,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - bool mla_attn; // whether to use MLA attention [EXPERIMENTAL] + int mla_attn; // whether to use MLA attention [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback diff --git a/src/llama.cpp b/src/llama.cpp index ebc7a772..f2c5f9d4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -111,14 +111,6 @@ #define LLAMA_MAX_LAYERS 512 #define LLAMA_MAX_EXPERTS 256 // DeepSeekV2 -// -// === MLA cache -// If tou are desperate to reduce KV cache size, set MLA_USE_TRANSPOSED_CACHE to 0. -// TG perfornce will be slower (similar to no-MLA), but KV cache size will be cut to ~half. -// PP performance will be about the same as with MLA_USE_TRANSPOSED_CACHE = 1. -// -#define MLA_USE_TRANSPOSED_CACHE 1 - // // helpers // @@ -2518,7 +2510,7 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; - bool mla_attn; + int mla_attn; bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -2695,9 +2687,7 @@ struct llama_kv_cache { // DeepSeek MLA std::vector kv_l; -#if MLA_USE_TRANSPOSED_CACHE std::vector kvt_l; -#endif std::vector ctxs; std::vector bufs; @@ -3175,9 +3165,9 @@ static bool llama_kv_cache_init( // DeepSeek MLA cache.kv_l.reserve(n_layer); -#if MLA_USE_TRANSPOSED_CACHE - cache.kvt_l.reserve(n_layer); -#endif + if (cparams.mla_attn == 1) { + cache.kvt_l.reserve(n_layer); + } bool warn = true; int n_mla = 0; @@ -3208,25 +3198,18 @@ static bool llama_kv_cache_init( const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); -#if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); - //ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_k, (kv_lora_rank + n_embd_head_qk_rope)*kv_size); -#else - ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_v, kv_lora_rank + n_embd_head_qk_rope, kv_size); -#endif + auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; + ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); -#if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "cache_kvt_l%d", i); - cache.kvt_l.push_back(kvt); -#endif + if (cparams.mla_attn == 1) { + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kvt_l.push_back(kvt); + } n_mla++; } else { - //printf("Creating cache tensors:\n"); - //printf("n_embd_k_gqa = %d, kv_size = %d, n_head = %d, n_head_kv = %d, n_embd_head_k = %d\n", (int)n_embd_k_gqa, (int)kv_size, (int)n_head, (int)n_head_kv, (int)n_embd_head_k); - //k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); k = ggml_new_tensor_2d(ctx, type_k, n_embd_head_k, n_head_kv*kv_size); v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); @@ -8940,7 +8923,7 @@ struct llm_build_context { const int32_t n_ctx_orig; const bool flash_attn; - const bool mla_attn; + const int mla_attn; const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; @@ -13546,20 +13529,22 @@ struct llm_build_context { if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) { -#if MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, - ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); - cb(kv_cache_trans_view, "kv_cache_trans_view", il); + ggml_tensor * kv_cache_trans; - // note: storing transposed c^KV in the transposed KV cache - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); + if (lctx.cparams.mla_attn == 1) { + ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); + cb(kv_cache_trans_view, "kv_cache_trans_view", il); - ggml_tensor * kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], - n_kv, kv_lora_rank, - ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), - 0); - cb(kv_cache_trans, "kv_cache_trans", il); -#endif + // note: storing transposed c^KV in the transposed KV cache + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_transpose(ctx0, kv_compressed), kv_cache_trans_view)); + + kv_cache_trans = ggml_view_2d(ctx0, kv_self.kvt_l[il], + n_kv, kv_lora_rank, + ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), + 0); + cb(kv_cache_trans, "kv_cache_trans", il); + } ggml_tensor * kvr = ggml_concat(ctx0, kv_compressed, ggml_permute(ctx0, k_rope, 0, 2, 1, 3), 0); cb(kvr, "kvr", il); @@ -13607,15 +13592,15 @@ struct llm_build_context { cb(kq, "kq_soft_max_ext_perm", il); } -#if !MLA_USE_TRANSPOSED_CACHE - ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], - kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_cache, "kv_cache_lora", il); + if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); - ggml_tensor * kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); - cb(kv_cache_trans, "kv_cache_trans", il); -#endif + kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); + } struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); cb(kqv_compressed, "kqv_compressed", il); @@ -17658,7 +17643,7 @@ struct llama_context_params llama_context_default_params() { /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, - /*.mla_attn =*/ false, + /*.mla_attn =*/ 0, /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -18140,18 +18125,23 @@ struct llama_context * llama_new_context_with_model( kv_type = kv->type; } -#if MLA_USE_TRANSPOSED_CACHE for (auto & kvt : ctx->kv_self.kvt_l) { memory_size_kvt += ggml_nbytes(kvt); kvt_type = kvt->type; } -#endif if (memory_size_kv + memory_size_kvt > 0) { - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, - (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), - ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), - ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); + if (cparams.mla_attn == 1) { + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), + ggml_type_name(kvt_type), (float)memory_size_kvt / (1024.0f * 1024.0f)); + } else { + GGML_ASSERT(memory_size_kvt == 0); + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T: not used\n", __func__, + (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), + ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f)); + } } } From a79ab8f34222e1e0142a30eaa97e78ad077abca9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 1 Mar 2025 08:25:27 +0200 Subject: [PATCH 026/132] Reduce size of compute buffers (#237) * This reduces compute buffer size for MLA * This should accomplish it for standard attention * Much better * Better concat for contiguous tensors If all the op does is to concatenate the second tensor to the first, why would we want to have a loop? --------- Co-authored-by: Iwan Kawrakow --- common/common.cpp | 8 + common/common.h | 3 +- examples/llama-bench/llama-bench.cpp | 35 ++++- ggml/src/ggml-cuda/concat.cu | 30 +++- ggml/src/ggml.c | 20 +++ include/llama.h | 1 + src/llama.cpp | 218 ++++++++++++++++++--------- 7 files changed, 236 insertions(+), 79 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6359426f..5c9070da 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -855,6 +855,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mla_attn = std::stoi(argv[i]); return true; } + if (arg == "-amb" || arg == "--attention-max-batch") { + CHECK_ARG + params.attn_max_batch = std::stoi(argv[i]); + return true; + } if (arg == "-fmoe" || arg == "--fused-moe") { params.fused_moe_up_gate = true; return true; @@ -1516,6 +1521,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); + options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -2360,6 +2366,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); @@ -3359,6 +3366,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); + fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index ef5175f3..f35f3558 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,8 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention - int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache + int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5756843a..a08cb762 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -233,6 +233,7 @@ struct cmd_params { std::vector no_kv_offload; std::vector flash_attn; std::vector mla_attn; + std::vector attn_max_batch; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = { /* no_kv_offload */ {false}, /* flash_attn */ {false}, /* mla_attn */ {0}, + /* attn_max_batch */ {0}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -301,6 +303,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); + printf(" -amb, --attn-max-batch (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -578,6 +581,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end()); + } else if (arg == "-amb" || arg == "--attn-max-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -690,6 +700,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } + if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -727,6 +738,7 @@ struct cmd_params_instance { bool no_kv_offload; bool flash_attn; int mla_attn; + int attn_max_batch; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -773,6 +785,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.mla_attn = mla_attn; + cparams.attn_max_batch = attn_max_batch; cparams.fused_moe_up_gate = fmoe; cparams.embeddings = embeddings; @@ -799,6 +812,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & nkvo : params.no_kv_offload) for (const auto & fa : params.flash_attn) for (const auto & mla : params.mla_attn) + for (const auto & amb : params.attn_max_batch) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -821,6 +835,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -852,6 +867,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -883,6 +899,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -914,6 +931,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .no_kv_offload= */ nkvo, /* .flash_attn = */ fa, /* .mla_attn = */ mla, + /* .attn_max_b = */ amb, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -956,6 +974,7 @@ struct test { bool no_kv_offload; bool flash_attn; int mla_attn; + int attn_max_batch; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -987,6 +1006,7 @@ struct test { no_kv_offload = inst.no_kv_offload; flash_attn = inst.flash_attn; mla_attn = inst.mla_attn; + attn_max_batch = inst.attn_max_batch; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -1081,7 +1101,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1097,7 +1117,7 @@ struct test { field == "n_threads" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || field == "main_gpu" || - field == "n_prompt" || field == "n_gen" || field == "mla_attn" || + field == "n_prompt" || field == "n_gen" || field == "mla_attn" || field == "attn_max_batch" || field == "avg_ns" || field == "stddev_ns") { return INT; } @@ -1138,7 +1158,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1305,6 +1325,9 @@ struct markdown_printer : public printer { if (field == "mla_attn") { return 3; } + if (field == "attn_max_batch") { + return 5; + } if (field == "use_mmap") { return 4; } @@ -1345,6 +1368,9 @@ struct markdown_printer : public printer { if (field == "mla_attn") { return "mla"; } + if (field == "attn_max_batch") { + return "amb"; + } if (field == "use_mmap") { return "mmap"; } @@ -1403,6 +1429,9 @@ struct markdown_printer : public printer { if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) { fields.emplace_back("mla_attn"); } + if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) { + fields.emplace_back("attn_max_batch"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index dac10ec3..4bde6d69 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -164,7 +164,12 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; - if (dim != 3) { + if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + } else { for (int i3 = 0; i3 < dst->ne[3]; i3++) { concat_f32_cuda( src0_d + i3 * (src0->nb[3] / 4), @@ -173,13 +178,24 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } - } else { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); } + + //if (dim != 3) { + // for (int i3 = 0; i3 < dst->ne[3]; i3++) { + // concat_f32_cuda( + // src0_d + i3 * (src0->nb[3] / 4), + // src1_d + i3 * (src1->nb[3] / 4), + // dst_d + i3 * ( dst->nb[3] / 4), + // src0->ne[0], src0->ne[1], src0->ne[2], + // dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + // } + //} else { + // const size_t size0 = ggml_nbytes(src0); + // const size_t size1 = ggml_nbytes(src1); + + // CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + // CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + //} } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); concat_f32_non_cont<<>>( diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 80dd25ff..91c0c5db 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -12627,6 +12627,26 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(dim >= 0 && dim < 4); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst) && + (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) { + // simply copy the data + const int64_t size_src_0 = ggml_nbytes(src0); + const int64_t size_src_1 = ggml_nbytes(src1); + const int64_t block_size = 4096; + const int64_t num_blocks = (size_src_0 + size_src_1 + block_size - 1)/block_size; + for (int64_t i_block = ith; i_block < num_blocks; i_block += nth) { + const int64_t start = i_block*block_size; + if (start < size_src_0) { + int64_t copy_size = MIN(block_size, size_src_0 - start); + memcpy((char *)dst->data + start, (char *)src0->data + start, copy_size); + } else { + int64_t copy_size = MIN(block_size, size_src_0 + size_src_1 - start); + memcpy((char *)dst->data + start, (char *)src1->data + start - size_src_0, copy_size); + } + } + return; + } + int64_t o[4] = {0, 0, 0, 0}; o[dim] = src0->ne[dim]; diff --git a/include/llama.h b/include/llama.h index 2b33701c..bb43aebc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -384,6 +384,7 @@ extern "C" { bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] int mla_attn; // whether to use MLA attention [EXPERIMENTAL] + int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] // Abort callback diff --git a/src/llama.cpp b/src/llama.cpp index f2c5f9d4..0dcc78dc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2511,6 +2511,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; int mla_attn; + int attn_max_batch; bool fused_moe_up_gate; enum llama_pooling_type pooling_type; @@ -8774,45 +8775,8 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens); } else { - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } - - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below - - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - - //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - //kq = ggml_scale(ctx, kq, 30); - - kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f); - } - - if (hparams.attn_soft_cap) { - //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); - kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias, - 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); - } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); - } - cb(kq, "kq_soft_max_ext", il); - - GGML_ASSERT(kv.size == n_ctx); - - // split cached v into n_head heads + // split cached v into n_head heads struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], n_kv, n_embd_head_v, n_head_kv, @@ -8821,14 +8785,98 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); + auto kq_size = k->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); + if (cparams.attn_max_batch == 0 || cparams.attn_max_batch >= kq_size || k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]) { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + //kq = ggml_scale(ctx, kq, 30); + + kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f); + } + + if (hparams.attn_soft_cap) { + //kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + kq = ggml_softcap_max(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias, + 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + } else { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + } + cb(kq, "kq_soft_max_ext", il); + + GGML_ASSERT(kv.size == n_ctx); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } + else { + // For now we will not support this option if k->ne[2] != q->ne[2] || v->ne[2] != q->ne[2]; + GGML_ASSERT(k->ne[2] == v->ne[2] && k->ne[2] == q->ne[2]); + int n_step = (kq_size + cparams.attn_max_batch - 1)/cparams.attn_max_batch; + n_step = std::min(n_step, int(k->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + auto r2k = q->ne[2] / k->ne[2]; + auto r2v = q->ne[2] / v->ne[2]; + n_step = q->ne[2]; + n_per_step = 1; + ggml_tensor * kqv; + for (int i12 = 0; i12 < q->ne[2]; i12 += n_per_step) { + int this_ne12 = i12 + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i12; + int i02 = i12/r2k; + auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02); + auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12); + auto kq_i = ggml_mul_mat(ctx, k_i, q_i); + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32); + } + if (model.arch == LLM_ARCH_GROK) { + kq_i = ggml_softcap(ctx, kq_i, 0.08838834764831845f/30.0f, 30.f); + } + if (hparams.attn_soft_cap) { + kq_i = ggml_softcap_max(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias, + 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping); + } else { + kq_i = ggml_soft_max_ext(ctx, kq_i, kq_mask, kq_scale, hparams.f_max_alibi_bias); + } + i02 = i12 / r2v; + auto v_i = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], this_ne12, v->nb[1], v->nb[2], v->nb[2]*i02); + auto kqv_i = ggml_mul_mat(ctx, v_i, kq_i); + if (i12 == 0) { + kqv = kqv_i; + } else { + kqv = ggml_concat(ctx, kqv, kqv_i, 2); + } + } + ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_v*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); + } } ggml_build_forward_expand(graph, cur); @@ -8924,6 +8972,7 @@ struct llm_build_context { const bool flash_attn; const int mla_attn; + const int attn_max_batch; const bool fused_moe_up_gate; const enum llama_pooling_type pooling_type; @@ -8976,6 +9025,7 @@ struct llm_build_context { n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), mla_attn (cparams.mla_attn), + attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -13572,25 +13622,6 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (!pp_opt) { - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_perm", il); - } - ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); - cb(kq, "kq", il); - - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); - } - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - if (!pp_opt) { - kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq, "kq_soft_max_ext_perm", il); - } if (lctx.cparams.mla_attn > 1) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], @@ -13602,12 +13633,60 @@ struct llm_build_context { cb(kv_cache_trans, "kv_cache_trans", il); } - struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); - cb(kqv_compressed, "kqv_compressed", il); + ggml_tensor * kqv_compressed; + auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); + } - if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + cb(kq, "kq", il); + + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + } else { + + int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + n_step = std::min(n_step, int(q->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + + //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); + + for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { + int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; + ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = kqv_i; + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); + } + cb(kqv_compressed, "kqv_compressed", il); } struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, @@ -17644,6 +17723,7 @@ struct llama_context_params llama_context_default_params() { /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.mla_attn =*/ 0, + /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, @@ -17844,6 +17924,7 @@ struct llama_context * llama_new_context_with_model( cparams.offload_kqv = params.offload_kqv; cparams.flash_attn = params.flash_attn; cparams.mla_attn = params.mla_attn; + cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; cparams.pooling_type = params.pooling_type; @@ -17912,6 +17993,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); + LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); From ef9a3d17b52bb5f6d55f7ef7e05e41e22f2ad81d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 1 Mar 2025 17:12:58 +0200 Subject: [PATCH 027/132] A better way to measure the cost of ggml_barrier (#238) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 91c0c5db..7ba5e1ad 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21459,14 +21459,26 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.shared=*/ state->shared, }; +#if IK_PRINT_TIMING + int64_t t_start = ggml_time_us(); + int64_t t_eval = 0; +#endif + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; if (ggml_is_noop(node)) continue; +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif if (ggml_compute_forward(¶ms, node, node_n < cgraph->n_nodes-1 ? cgraph->nodes[node_n+1] : NULL)) { ++node_n; } +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + t_eval += tim2 - tim1; +#endif if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { state->shared->ec = GGML_STATUS_ABORTED; @@ -21478,6 +21490,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { break; } } +#if IK_PRINT_TIMING + int64_t t_end = ggml_time_us(); + if (state->ith == 0) printf("ggml_barrier(...): %d us\n", (int)(t_end - t_start - t_eval)); +#endif return 0; } From a89adaa78f505675be7be6180f419b4b0158c15a Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 2 Mar 2025 13:47:38 +0200 Subject: [PATCH 028/132] SER - Smart Expert Reduction (#239) * A better way to measure the cost of ggml_barrier * Smart expert selection * Add ser option to llama-bench --------- Co-authored-by: Iwan Kawrakow --- common/common.cpp | 35 +++++++ common/common.h | 2 + examples/llama-bench/llama-bench.cpp | 65 +++++++++++- ggml/include/ggml.h | 13 +++ ggml/src/ggml-cuda.cu | 16 ++- ggml/src/ggml-cuda/argsort.cu | 47 +++++++-- ggml/src/ggml-cuda/argsort.cuh | 2 + ggml/src/ggml-cuda/getrows.cu | 17 ++-- ggml/src/ggml.c | 143 ++++++++++++++++++++++++++- include/llama.h | 2 + src/llama.cpp | 15 ++- 11 files changed, 330 insertions(+), 27 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 5c9070da..e62944b9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -322,6 +322,26 @@ bool parse_buft_overrides(const std::string& value, std::vector +std::vector> string_split_pairs(const std::string & str, char delim) { + std::vector> values; + std::istringstream str_stream(str); + std::string token; + T1 first_value; + int i = 0; + while (std::getline(str_stream, token, delim)) { + std::istringstream token_stream(token); + if (i%2 == 0) { + token_stream >> first_value; + } else { + T2 value; + token_stream >> value; + values.emplace_back(first_value, value); + } + i++; + } + return values; +} } #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } @@ -864,6 +884,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.fused_moe_up_gate = true; return true; } + if (arg == "-ser" || arg == "--smart-expert-reduction") { + CHECK_ARG + auto values = string_split_pairs(argv[i], ','); + if (values.size() == 1) { + params.min_experts = values.front().first; + params.thresh_experts = values.front().second; + } else { + invalid_param = true; + } + return true; + } if (arg == "-co" || arg == "--color") { params.use_color = true; return true; @@ -1523,6 +1554,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch}); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); + options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" "(default: '%s')", params.prompt.c_str() }); @@ -2368,6 +2400,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate = params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -3368,6 +3402,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); + fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index f35f3558..f6a55885 100644 --- a/common/common.h +++ b/common/common.h @@ -178,6 +178,8 @@ struct gpt_params { int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false) bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models + int min_experts = -1; + float thresh_experts = 0; bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index a08cb762..167525bc 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -215,6 +215,9 @@ static std::string pair_str(const std::pair & p) { return buf; } +// Ser = Smart Expert Reduction +using Ser = std::pair; + struct cmd_params { std::vector model; std::vector n_prompt; @@ -234,6 +237,7 @@ struct cmd_params { std::vector flash_attn; std::vector mla_attn; std::vector attn_max_batch; + std::vector ser; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -267,6 +271,7 @@ static const cmd_params cmd_params_defaults = { /* flash_attn */ {false}, /* mla_attn */ {0}, /* attn_max_batch */ {0}, + /* ser */ {{-1,0.0f}}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -304,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str()); printf(" -amb, --attn-max-batch (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); + printf(" -ser, --smart-expert-reduction (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); @@ -387,6 +393,28 @@ bool parse_buft_overrides(const std::string& value, std::vector +std::vector> string_split_pairs(const std::string & str, char delim) { + std::vector> values; + std::istringstream str_stream(str); + std::string token; + T1 first_value; + int i = 0; + while (std::getline(str_stream, token, delim)) { + std::istringstream token_stream(token); + if (i%2 == 0) { + token_stream >> first_value; + if (token_stream.fail()) return {}; + } else { + T2 value; + token_stream >> value; + if (token_stream.fail()) return {}; + values.emplace_back(first_value, value); + } + i++; + } + return values; +} } static cmd_params parse_cmd_params(int argc, char ** argv) { @@ -588,6 +616,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end()); + } else if (arg == "-ser" || arg == "--smart-expert-reduction") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split_pairs(argv[i], split_delim); + params.ser.insert(params.ser.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -701,6 +736,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; } if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; } + if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -739,6 +775,7 @@ struct cmd_params_instance { bool flash_attn; int mla_attn; int attn_max_batch; + Ser ser; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -787,6 +824,8 @@ struct cmd_params_instance { cparams.mla_attn = mla_attn; cparams.attn_max_batch = attn_max_batch; cparams.fused_moe_up_gate = fmoe; + cparams.min_experts = ser.first; + cparams.thresh_experts = ser.second; cparams.embeddings = embeddings; return cparams; @@ -813,6 +852,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & fa : params.flash_attn) for (const auto & mla : params.mla_attn) for (const auto & amb : params.attn_max_batch) + for (const auto & ser : params.ser) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -836,6 +876,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -868,6 +909,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -900,6 +942,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -932,6 +975,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .flash_attn = */ fa, /* .mla_attn = */ mla, /* .attn_max_b = */ amb, + /* .ser = */ ser, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -975,6 +1019,7 @@ struct test { bool flash_attn; int mla_attn; int attn_max_batch; + Ser ser; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -1007,6 +1052,7 @@ struct test { flash_attn = inst.flash_attn; mla_attn = inst.mla_attn; attn_max_batch = inst.attn_max_batch; + ser = inst.ser; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -1101,7 +1147,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", + "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1149,6 +1195,11 @@ struct test { tensor_split_str += "/"; } } + auto ser_to_string = [] (const Ser& ser) { + std::ostringstream str; + str << ser.first << ',' << ser.second; + return str.str(); + }; std::vector values = { build_commit, std::to_string(build_number), std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan), @@ -1158,7 +1209,8 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), + std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -1328,6 +1380,9 @@ struct markdown_printer : public printer { if (field == "attn_max_batch") { return 5; } + if (field == "ser") { + return 10; + } if (field == "use_mmap") { return 4; } @@ -1371,6 +1426,9 @@ struct markdown_printer : public printer { if (field == "attn_max_batch") { return "amb"; } + if (field == "attn_max_batch") { + return "ser"; + } if (field == "use_mmap") { return "mmap"; } @@ -1432,6 +1490,9 @@ struct markdown_printer : public printer { if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) { fields.emplace_back("attn_max_batch"); } + if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) { + fields.emplace_back("ser"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d12b90d0..91219d4a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -597,6 +597,7 @@ extern "C" { GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, + GGML_OP_ARGSORT_THRESH, GGML_OP_LEAKY_RELU, GGML_OP_SOFTCAP, GGML_OP_SOFT_CAP_MAX, @@ -1913,6 +1914,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + GGML_API struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float threshold); + GGML_API struct ggml_tensor * ggml_arange( struct ggml_context * ctx, float start, @@ -1924,6 +1931,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, int k); + GGML_API struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh); #define GGML_KQ_MASK_PAD 32 diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index bc960678..85df0694 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2133,7 +2133,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); const int64_t i11 = id % ne11; const int64_t i12 = iid1; @@ -2162,7 +2163,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); if (row_id_i != i02) { continue; @@ -2301,7 +2303,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); const int64_t i11 = id % ne11; const int64_t i12 = iid1; @@ -2362,7 +2365,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + if (row_id_i < 0 || row_id_i >= n_as) continue; + //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); if (row_id_i != i02) { continue; @@ -2637,6 +2641,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_ARGSORT_THRESH: + ggml_cuda_op_argsort_thresh(ctx, dst); + break; case GGML_OP_FLASH_ATTN_EXT: ggml_cuda_flash_attn_ext(ctx, dst); break; @@ -3252,6 +3259,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_ACC: case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 607ded85..1734b771 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -8,7 +8,8 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) { } template -static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, + int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -51,9 +52,18 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n } } - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + __syncthreads(); + float max_val = x_row[dst_row[0]]; + if (col < ncols) { + dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + } + } + else { + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } } @@ -65,7 +75,8 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { +static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, + ggml_sort_order order, int min_experts, float thresh_experts, cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); @@ -77,9 +88,9 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else { GGML_ABORT("fatal error"); } @@ -100,5 +111,25 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, -1, 0.f, stream); +} + +void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + int min_experts = dst->op_params[0]; + float thresh; + memcpy(&thresh, dst->op_params + 1, sizeof(float)); + + argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, GGML_SORT_ORDER_DESC, min_experts, thresh, stream); } diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 68a00154..4bafa2d7 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_argsort_thresh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4c370323..973b6526 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -4,7 +4,7 @@ template static __global__ void k_get_rows( const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, @@ -31,7 +31,11 @@ static __global__ void k_get_rows( // dequantize dfloat2 v; - dequantize_kernel(src0_row, ib, iqs, v); + if (i01 >= 0 && i01 < ne01) { + dequantize_kernel(src0_row, ib, iqs, v); + } else { + v.x = v.y = 0; + } dst_row[iybs + iqs + 0] = v.x; dst_row[iybs + iqs + y_offset] = v.y; @@ -40,7 +44,7 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ + int64_t ne00, int64_t ne01, /*int64_t ne02, int64_t ne03,*/ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ /*size_t s0,*/ size_t s1, size_t s2, size_t s3, /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, @@ -56,11 +60,10 @@ static __global__ void k_get_rows_float( } const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = src0_row[i00]; + dst_row[i00] = i01 >= 0 && i01 < ne01 ? dst_t(src0_row[i00]) : dst_t(0); } template @@ -88,7 +91,7 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg k_get_rows<<>>( src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ + ne00, ne01, /*ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, @@ -120,7 +123,7 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr k_get_rows_float<<>>( src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ + ne00, ne01, /*ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7ba5e1ad..31fbc57e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3875,6 +3875,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", + "ARGSORT_THRESH", "LEAKY_RELU", "SOFTCAP", "SOFT_CAP_MAX", @@ -3905,7 +3906,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3969,6 +3970,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", + "argsort_thresh(x)", "leaky_relu(x)", "k2*tanh(k1*x)", "soft_max(k2*tanh(k1*x))", @@ -3999,7 +4001,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -8497,6 +8499,27 @@ struct ggml_tensor * ggml_argsort( return result; } +// ggml_argsort + +struct ggml_tensor * ggml_argsort_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int min_entries, + float thresh) { + bool is_node = false; + + //printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) min_entries); + ggml_set_op_params_f32(result, 1, thresh); + + result->op = GGML_OP_ARGSORT_THRESH; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} // ggml_top_k @@ -8516,6 +8539,32 @@ struct ggml_tensor * ggml_top_k( return result; } +// ggml_top_k_thresh + +struct ggml_tensor * ggml_top_k_thresh( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k, + int min_entries, + float thresh) { + GGML_ASSERT(a->ne[0] >= k); + + //printf("%s: k = %d, min_entries = %d, thresh = %g\n", __func__, k, min_entries, (double)thresh); + struct ggml_tensor * result; + if (min_entries <= 0 || thresh <= 0) { + result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC); + } else { + result = ggml_argsort_thresh(ctx, a, min_entries, thresh); + } + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( @@ -14485,7 +14534,8 @@ static void ggml_compute_forward_mul_mat_id( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -14737,7 +14787,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( for (int id = 0; id < n_ids; ++id) { const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //assert(i02 >= 0 && i02 < n_as); MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; matrix_row_counts[i02] += 1; @@ -15580,7 +15631,11 @@ static void ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 < 0 || i01 >= ne01) { + memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + continue; + } + //assert(i01 >= 0 && i01 < ne01); dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), @@ -17667,6 +17722,75 @@ static void ggml_compute_forward_argsort( } } +// ggml_compute_forward_argsort_thresh + +static void ggml_compute_forward_argsort_thresh_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + int min_entries = ggml_get_op_params_i32(dst, 0); + float thresh = ggml_get_op_params_f32(dst, 1); + + //if (ith == 0) printf("%s: min_entries = %d, thresh = %g\n", __func__, min_entries, (double)thresh); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if (src_data[dst_data[j]] < src_data[dst_data[k]]) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + float max_value = src_data[dst_data[0]]; + //printf("Row %ld: max_value is %g, next is %g\n", i, (double)max_value, (double)src_data[dst_data[1]]); + for (int j = min_entries; j < ne0; ++j) { + if (src_data[dst_data[j]] < max_value*thresh) { + //printf(" row %ld: turning off expert %d(%d) with value %g\n", i, j, dst_data[j], (double)src_data[dst_data[j]]); + dst_data[j] = -1; + } + } + } +} + +static void ggml_compute_forward_argsort_thresh( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_thresh_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -19476,6 +19600,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argsort(params, tensor); } break; + case GGML_OP_ARGSORT_THRESH: + { + ggml_compute_forward_argsort_thresh(params, tensor); + } break; case GGML_OP_LEAKY_RELU: { ggml_compute_forward_leaky_relu(params, tensor); @@ -20461,6 +20589,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ABORT("fatal error"); // TODO: not implemented } + case GGML_OP_ARGSORT_THRESH: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -21181,6 +21313,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: + case GGML_OP_ARGSORT_THRESH: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: diff --git a/include/llama.h b/include/llama.h index bb43aebc..38a12744 100644 --- a/include/llama.h +++ b/include/llama.h @@ -386,6 +386,8 @@ extern "C" { int mla_attn; // whether to use MLA attention [EXPERIMENTAL] int attn_max_batch; // maximum batch size for attention computations [EXPERIMENTAL] bool fused_moe_up_gate; // whether to use fused MoE up/down op [EXPERIMENTAL] + int min_experts; + float thresh_experts; // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index 0dcc78dc..3a8b54ca 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2513,6 +2513,8 @@ struct llama_cparams { int mla_attn; int attn_max_batch; bool fused_moe_up_gate; + int min_experts; + float thresh_experts; enum llama_pooling_type pooling_type; @@ -8631,7 +8633,8 @@ llm_expert_gating_func_type gating_op, } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); @@ -8974,6 +8977,8 @@ struct llm_build_context { const int mla_attn; const int attn_max_batch; const bool fused_moe_up_gate; + const int min_experts; + const float thresh_experts; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -9027,6 +9032,8 @@ struct llm_build_context { mla_attn (cparams.mla_attn), attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), + min_experts (cparams.min_experts), + thresh_experts (cparams.thresh_experts), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -17725,6 +17732,8 @@ struct llama_context_params llama_context_default_params() { /*.mla_attn =*/ 0, /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, + /*.min_experts =*/ -1, + /*.thtesh_experts =*/ 0.0f, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17926,6 +17935,9 @@ struct llama_context * llama_new_context_with_model( cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17995,6 +18007,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); + LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); From a87e54db6ec2409284a55f029d4abe9e50990064 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 3 Mar 2025 15:17:51 +0200 Subject: [PATCH 029/132] Flash MLA (CPU only) (#240) * FlashMLA - it finally works (on the CPU) * FlashMLA: allow for f16 and bf16 cache in addition to q8_0 * It works with ggml FA, not with iqk FA * WIP * FlashMLA: it now works with iqk I had forgotten to divide the Q stride by sizeof(float) and that's why, very cobfusingly, it was working for TG but not for PP. * WIP * FlashMLA: that should be it for now --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-quants.c | 8 +- ggml/src/ggml.c | 6 +- ggml/src/iqk/iqk_mul_mat.cpp | 94 ++++++++++++++++++--- src/llama.cpp | 154 ++++++++++++++++++++--------------- 4 files changed, 183 insertions(+), 79 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e8218e76..e39cf4aa 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -11,6 +11,7 @@ #include "ggml-quants.h" #include "ggml-impl.h" #if GGML_USE_IQK_MULMAT +#include "iqk/iqk_config.h" #include "iqk/iqk_mul_mat.h" #include "iqk/iqk_quantize.h" #endif @@ -5449,7 +5450,12 @@ void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { +#ifdef HAVE_FANCY_SIMD + enum ggml_type dot_type = GGML_TYPE_Q8_1_X4; +#else + enum ggml_type dot_type = GGML_TYPE_Q8_0_X4; +#endif + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, dot_type, vy, by, s, bs, 0, 1)) { return; } #endif diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 31fbc57e..46e1a548 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10451,7 +10451,7 @@ static void ggml_compute_forward_dup_bytes( ne00 == ne0 && nb00 == type_size && nb0 == type_size) { // copy by rows - const size_t rs = ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ir0; i01 < ir1; i01++) { @@ -17871,6 +17871,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( #if GGML_USE_IQK_MULMAT if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { + //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", + // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); // I keep changing my mind what is the best strategy to split the threads when processing // multiple heads. This is my current thinking, the commented out code below was the previous. int ntg = nth/simple_gcd(neq2*neq3, nth); @@ -17906,8 +17908,8 @@ static void ggml_compute_forward_flash_attn_ext_f16( } return; IQK_Flash_Attn_NotAvailable:; + printf("iqk_flash was rejected\n"); } - #endif const uint32_t n_head = neq2; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 0955f15d..1f18837c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15016,7 +15016,7 @@ template struct BaseHelper { BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {} - inline void set_block(int k1) { block = data + k1*k_step*stride; } + //inline void set_block(int k1) { block = data + k1*k_step*stride; } inline void reset_block() { block = data; } inline void next_block() { block += k_step*stride; } inline const char * lblock(int l1) const { return block + l1*stride; } @@ -16038,9 +16038,9 @@ struct FlashQKV { } inline void normalize_and_store(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { - //GGML_ASSERT(fms.S[j] > 0); - //auto norm = F16::set1(1/fms.S[j]); - auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); + GGML_ASSERT(fms.S[j] > 0); + auto norm = F16::set1(1/fms.S[j]); + //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); @@ -16076,7 +16076,7 @@ struct FlashQKV { template struct FlashQKfp32 { - static_assert(D%F16::block_size == 0 && D <= 256); + static_assert(D%F16::block_size == 0 && D <= 576); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16571,8 +16571,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, // q_step-1 versions of these functions for us, which I though was too much with q_step = 8. template struct FlashAttn { - static_assert(Dk%F16::block_size == 0 && Dk <= 256); - static_assert(Dv%F16::block_size == 0 && Dv <= 256); + static_assert(Dk%F16::block_size == 0 && Dk <= 576); + static_assert(Dv%F16::block_size == 0 && Dv <= 512); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16665,7 +16665,8 @@ struct HelperBF16 final : public BaseHelper { template struct FlashQKbf16 { - static_assert(D%32 == 0 && D <= 256); + //static_assert(D%32 == 0 && D <= 256); + static_assert(D%32 == 0 && D <= 576); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -16975,8 +16976,10 @@ struct FlashQKbf16 { template struct FlashAttnBF16 { - static_assert(Dk%32 == 0 && Dk <= 256); - static_assert(Dv%32 == 0 && Dv <= 256); + //static_assert(Dk%32 == 0 && Dk <= 256); + //static_assert(Dv%32 == 0 && Dv <= 256); + static_assert(Dk%32 == 0 && Dk <= 576); + static_assert(Dv%32 == 0 && Dv <= 512); static_assert(k_step%32 == 0); static_assert(q_step <= 4 || q_step%4 == 0); @@ -17216,6 +17219,66 @@ inline bool flash_attn_is_supported(ggml_type type) { #endif return false; } + +template +inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, + int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, + const float * q, const char * mask, float scale, float softcap, float * qkv) { + if (nq1 % 8 == 0) { + FlashAttn<576, 512, 8, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv); + } else { + FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv); + } +} + +template +inline bool iqk_deepseek_helper(ggml_type type_k, + int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, + const float * q, const char * k, const char * v, const char * mask, + float scale, float softcap, float * qkv) { + if (type_k == GGML_TYPE_Q8_0) { + HelperQ80<576, step_k> kh((const char *)k, stride_k); + HelperQ80<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } + if (type_k == GGML_TYPE_Q6_0) { + HelperQ60<576, step_k> kh((const char *)k, stride_k); + HelperQ60<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } + if (type_k == GGML_TYPE_Q8_KV) { + HelperQ8KV<576, step_k> kh((const char *)k, stride_k); + HelperQ8KV<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } + if (type_k == GGML_TYPE_F16) { + HelperF16<576, step_k> kh((const char *)k, stride_k); + HelperF16<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + return true; + } +#ifdef __AVX512BF16__ + if (type_k == GGML_TYPE_BF16) { + HelperBF16<576, step_k> kh((const char *)k, stride_k); + HelperBF16<512, step_k> vh((const char *)v, stride_v); + if (nq1 % 8 == 0) { + FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } else { + FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + } + return true; + } +#endif + return false; +} + } bool iqk_flash_attn_noalibi(int int_type_k, // type of k @@ -17237,10 +17300,19 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv) { // v*softmax(scale*(k*q)) + if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 + auto type_k = ggml_type(int_type_k); auto type_v = ggml_type(int_type_v); + + if (Dk == 576 && Dv == 512) { + GGML_ASSERT(type_k == type_v); + stride_q /= sizeof(float); // q stride as float + return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv); + } + if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; - if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 if (Dk != Dv && Dk != 192 && Dv != 128) return false; if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false; diff --git a/src/llama.cpp b/src/llama.cpp index 3a8b54ca..5ac44055 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3168,7 +3168,7 @@ static bool llama_kv_cache_init( // DeepSeek MLA cache.kv_l.reserve(n_layer); - if (cparams.mla_attn == 1) { + if (cparams.mla_attn == 1 && !cparams.flash_attn) { cache.kvt_l.reserve(n_layer); } @@ -3201,14 +3201,20 @@ static bool llama_kv_cache_init( const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); - auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; - ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); - ggml_format_name(kv, "cache_kv_l%d", i); - cache.kv_l.push_back(kv); - if (cparams.mla_attn == 1) { - ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); - ggml_format_name(kvt, "cache_kvt_l%d", i); - cache.kvt_l.push_back(kvt); + if (cparams.flash_attn) { + ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); + } else { + auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; + ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); + ggml_format_name(kv, "cache_kv_l%d", i); + cache.kv_l.push_back(kv); + if (cparams.mla_attn == 1) { + ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_v, kv_lora_rank*kv_size); + ggml_format_name(kvt, "cache_kvt_l%d", i); + cache.kvt_l.push_back(kvt); + } } n_mla++; } @@ -13588,7 +13594,7 @@ struct llm_build_context { ggml_tensor * kv_cache_trans; - if (lctx.cparams.mla_attn == 1) { + if (lctx.cparams.mla_attn == 1 && !lctx.cparams.flash_attn) { ggml_tensor * kv_cache_trans_view = ggml_view_2d(ctx0, kv_self.kvt_l[il], n_tokens, kv_lora_rank, ggml_row_size(kv_self.kvt_l[il]->type, kv_self.size), ggml_row_size(kv_self.kvt_l[il]->type, kv_head)); cb(kv_cache_trans_view, "kv_cache_trans_view", il); @@ -13630,70 +13636,88 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kqv_compressed; + + if (lctx.cparams.flash_attn) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_cache, "kv_cache_lora", il); + cb(kv_cache_lora, "kv_cache_lora", il); - kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); - cb(kv_cache_trans, "kv_cache_trans", il); + //ggml_tensor * v = ggml_cont(ctx0, kv_cache_lora); + //kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + + kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv_compressed, "kqv_compressed", il); + + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); } + else { + if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); - ggml_tensor * kqv_compressed; - auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB - if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { - if (!pp_opt) { - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_perm", il); + kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); } - ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); - cb(kq, "kq", il); - - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); - } - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - if (!pp_opt) { - kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq, "kq_soft_max_ext_perm", il); - } - - kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); - cb(kqv_compressed, "kqv_compressed", il); - - if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); - } - - } else { - - int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; - n_step = std::min(n_step, int(q->ne[2])); - int n_per_step = (q->ne[2] + n_step - 1)/n_step; - - //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); - - for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { - int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; - ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); - ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); - kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); - if (i_head == 0) { - kqv_compressed = kqv_i; - } else { - kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); } - ggml_build_forward_expand(gf, kqv_compressed); + + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + cb(kq, "kq", il); + + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + } else { + + int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + n_step = std::min(n_step, int(q->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + + //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); + + for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { + int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; + ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = kqv_i; + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); + } + cb(kqv_compressed, "kqv_compressed", il); } - cb(kqv_compressed, "kqv_compressed", il); } struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, @@ -18226,7 +18250,7 @@ struct llama_context * llama_new_context_with_model( } if (memory_size_kv + memory_size_kvt > 0) { - if (cparams.mla_attn == 1) { + if (cparams.mla_attn == 1 && !cparams.flash_attn) { LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, c^KV (%s): %7.2f MiB, kv^T (%s): %7.2f MiB\n", __func__, (float)(memory_size_kv + memory_size_kvt) / (1024.0f * 1024.0f), ggml_type_name(kv_type), (float)memory_size_kv / (1024.0f * 1024.0f), From 7bdbf99bbdbfe46b01f7783a7c98a30a1558e2c3 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 5 Mar 2025 07:27:49 +0200 Subject: [PATCH 030/132] DeepSeek CUDA Flash Attention (#241) * WIP CUDA FA with Dk != Dv * WIP * CUDA FA WIP - It actually works! No TG yet, but for PP I can run FA with fp16 cache and it gets the same answer. * CUDA FA WIP - it now works for Q8_0 + Q8_0 for KV cache * CUDA FA WIP - TG, not working yet. * CUDA FA with Dk != Dv: it works now for DeepSeek --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 4 + ggml/src/ggml-cuda/fattn-common.cuh | 108 +++++++++-------- ggml/src/ggml-cuda/fattn-tile-f16.cu | 4 +- ggml/src/ggml-cuda/fattn-tile-f32.cu | 60 +++++----- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 110 ++++++++++-------- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 110 ++++++++++-------- ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 109 +++++++++-------- ggml/src/ggml-cuda/fattn.cu | 110 +++++++++++++----- .../fattn-vec-f16-instance-hs192-f16-f16.cu | 5 + .../fattn-vec-f16-instance-hs192-q8_0-q8_0.cu | 5 + .../fattn-vec-f32-instance-hs192-f16-f16.cu | 5 + .../fattn-vec-f32-instance-hs192-q8_0-q8_0.cu | 5 + .../fattn-wmma-f16-instance-kqfloat-cpb16.cu | 2 + .../fattn-wmma-f16-instance-kqfloat-cpb32.cu | 2 + .../fattn-wmma-f16-instance-kqhalf-cpb16.cu | 2 + .../fattn-wmma-f16-instance-kqhalf-cpb32.cu | 2 + .../fattn-wmma-f16-instance-kqhalf-cpb8.cu | 2 + src/llama.cpp | 6 +- 18 files changed, 392 insertions(+), 259 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 85df0694..410c6406 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3275,6 +3275,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->src[0]->ne[0] == 128) { return true; } + if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) { + return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || + (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); + } if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { return true; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 0a664dbd..a46f03e5 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -52,7 +52,7 @@ typedef half (*vec_dot_KQ_f16_t)( typedef float (*vec_dot_KQ_f32_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -62,7 +62,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -92,7 +92,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -102,7 +102,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -142,7 +142,7 @@ static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) return *((const int *) &val0_8); } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -152,7 +152,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -179,7 +179,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -189,7 +189,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -226,7 +226,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -277,7 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -287,7 +287,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_1; @@ -320,7 +320,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -330,7 +330,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( T sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const int ib = k_KQ / QI8_0; @@ -353,7 +353,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( return sum; } -template +template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -368,7 +368,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( half2 sum2 = make_half2(0.0f, 0.0f); #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -384,7 +384,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; const half2 K_ik = K_h2[k_KQ]; @@ -603,29 +603,29 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v return x[i]; } -template +template constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : nullptr; } -template +template constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : nullptr; } @@ -653,20 +653,20 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // Dv == V head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dv, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, float * __restrict__ dst) { - VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; - dst += D * gridDim.y*blockIdx.x; + VKQ_parts += parallel_blocks*Dv * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += Dv * gridDim.y*blockIdx.x; const int tid = threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dv); __shared__ float2 meta[parallel_blocks]; if (tid < 2*parallel_blocks) { @@ -690,20 +690,20 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*Dv + blockIdx.y*Dv + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; + dst[blockIdx.y*Dv + tid] = VKQ_numerator / VKQ_denominator; } -static void on_no_fattn_vec_case(const int D) { - if (D == 64) { +static void on_no_fattn_vec_case(const int Dk, const int Dv) { + if (Dk == 64 && Dv == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); fprintf(stderr, "By default only f16 KV cache is supported.\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); GGML_ABORT("fatal error"); - } else if (D == 128) { + } else if (Dk == 128 && Dv == 128) { fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); fprintf(stderr, "Supported combinations:\n"); fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); @@ -715,14 +715,22 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); + } + else if (Dk == 192 && Dv == 128) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n"); + // TODO: add what is supported + } + else if (Dk == 576 && Dv == 512) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n"); + // TODO: add what is supported } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V @@ -838,11 +846,11 @@ void launch_fattn( return; } - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(Dv, 1, 1); const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); const int shmem_combine = 0; - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index d1bbf01f..bf2a4521 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -291,13 +291,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 25908d7a..28846561 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -44,8 +44,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -61,15 +62,22 @@ static __global__ void flash_attn_tile_ext_f32( const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; - const int stride_KV2 = nb11 / sizeof(half2); + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb12 / sizeof(half2); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + // TODO: is it Dk or Dv or both that need to be multiple of 2*WARP_SIZE ? + // let's assume it is is both. + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + + constexpr int Dkv = Dk < Dv ? Dv : Dk; // let's use this when we don't understand if it is Dk or Dv __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. + // This is being used to store either K or V data. Hence we need max(Dk, Dv) as the dimension + __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][Dkv + 1]; // Pad D to avoid memory bank conflicts. float2 * KV_tmp2 = (float2 *) KV_tmp; float kqmax[ncols/nwarps]; @@ -79,16 +87,16 @@ static __global__ void flash_attn_tile_ext_f32( } float kqsum[ncols/nwarps] = {0.0f}; - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; + float2 VKQ[ncols/nwarps][(Dv/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; + __shared__ float Q_f[ncols][Dk]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += 2*WARP_SIZE) { float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; @@ -112,8 +120,8 @@ static __global__ void flash_attn_tile_ext_f32( const int i_KQ = i_KQ_0 + threadIdx.y; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 2*WARP_SIZE) { + const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_K2 + k_KQ_0/2 + threadIdx.x]; KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); } @@ -124,7 +132,7 @@ static __global__ void flash_attn_tile_ext_f32( float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; #pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { + for (int k_KQ = 0; k_KQ < Dk; ++k_KQ) { float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; float Q_k[ncols/nwarps]; @@ -193,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; } @@ -206,11 +214,11 @@ static __global__ void flash_attn_tile_ext_f32( const int k = k0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); - KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); + KV_tmp2[k*(Dv/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); + KV_tmp2[k*(Dv/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_V2 + i]); } } @@ -218,14 +226,14 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; + float2 V_k[(Dv/2)/WARP_SIZE]; float KQ_k[ncols/nwarps]; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; + V_k[i0/WARP_SIZE] = KV_tmp2[k*(Dv/2) + i]; } #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -235,7 +243,7 @@ static __global__ void flash_attn_tile_ext_f32( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; @@ -259,7 +267,7 @@ static __global__ void flash_attn_tile_ext_f32( kqsum_j = warp_reduce_sum(kqsum_j); #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { + for (int i00 = 0; i00 < Dv; i00 += 2*WARP_SIZE) { const int i0 = i00 + 2*threadIdx.x; float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; @@ -268,8 +276,8 @@ static __global__ void flash_attn_tile_ext_f32( dst_val.y /= kqsum_j; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 0] = dst_val.x; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + i0 + 1] = dst_val.y; } if (parallel_blocks != 1 && threadIdx.x == 0) { @@ -285,14 +293,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 7f14e78b..b6ba07e4 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,9 +1,9 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, @@ -43,14 +43,15 @@ static __global__ void flash_attn_vec_ext_f16( const int ne3) { #ifdef FP16_AVAILABLE // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); + constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); @@ -67,12 +68,13 @@ static __global__ void flash_attn_vec_ext_f16( const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = Dk / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dk); - __shared__ half KQ[ncols*D]; + __shared__ half KQ[ncols*Dk]; half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; @@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; - half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + half2 Q_h2[ncols][Dk/(2*WARP_SIZE)]; + int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1]; if (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f16( } // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int)); // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; } - if (threadIdx.x < D/QK8_1) { + if (threadIdx.x < Dk/QK8_1) { tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); } continue; @@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f16( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + Dk/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -154,7 +156,7 @@ static __global__ void flash_attn_vec_ext_f16( const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); @@ -166,13 +168,13 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; + KQ[j*Dk + tid] = -HALF_MAX_HALF; } half2 VKQ[ncols] = {{0.0f, 0.0f}}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, @@ -186,10 +188,10 @@ static __global__ void flash_attn_vec_ext_f16( } #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -209,7 +211,7 @@ static __global__ void flash_attn_vec_ext_f16( } if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; + KQ[j*Dk + i_KQ] = sum; } } } @@ -234,9 +236,9 @@ static __global__ void flash_attn_vec_ext_f16( const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); kqmax[j] = kqmax_new_j; - const half val = hexp(KQ[j*D + tid] - kqmax[j]); + const half val = hexp(KQ[j*Dk + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; + KQ[j*Dk + tid] = val; VKQ[j] *= __half2half2(KQ_max_scale); } @@ -244,8 +246,8 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); #pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + for (int k0 = 0; k0 < Dv; k0 += 2) { + if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k0 >= ne11) { break; } @@ -254,7 +256,7 @@ static __global__ void flash_attn_vec_ext_f16( reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2]; } } @@ -285,27 +287,28 @@ static __global__ void flash_attn_vec_ext_f16( dst_val /= kqsum[j_VKQ]; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; } if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr int nwarps = Dk/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + constexpr bool need_f16_K = Dk != 128 && Dk != 192; + constexpr bool need_f16_V = Dv != 128 && Dv != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -347,9 +350,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -358,9 +361,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } @@ -368,15 +371,19 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); @@ -435,3 +442,6 @@ extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 1aa88272..404afe2e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,9 +1,9 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(D, 1) +__launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, @@ -42,14 +42,15 @@ static __global__ void flash_attn_vec_ext_f32( const int ne2, const int ne3) { // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); + constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); @@ -64,15 +65,16 @@ static __global__ void flash_attn_vec_ext_f32( const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; + static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); + static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = Dk / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); + __builtin_assume(tid < Dk); - __shared__ float KQ[ncols*D]; + __shared__ float KQ[ncols*Dk]; #pragma unroll for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -FLT_MAX/2.0f; + KQ[j*Dk + tid] = -FLT_MAX/2.0f; } float kqmax[ncols]; @@ -94,9 +96,9 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: - float2 Q_f2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; - float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + float2 Q_f2[ncols][Dk/(2*WARP_SIZE)]; + int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk >= Dk/(sizeof(int)*QK8_1)]; + float2 Q_ds[ncols][Dk/QK8_1 == 0 ? 1 : Dk/QK8_1]; if (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -107,18 +109,18 @@ static __global__ void flash_attn_vec_ext_f32( } // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int)); // Set memory to zero if out of bounds: if (ncols > 2 && ic0 + j >= ne01) { #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; tmp_q_i32[i] = 0; } - if (threadIdx.x < D/QK8_1) { + if (threadIdx.x < Dk/QK8_1) { tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); } continue; @@ -126,7 +128,7 @@ static __global__ void flash_attn_vec_ext_f32( const float * Q_f = (const float *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); } } @@ -135,11 +137,11 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + int * tmp_q_i32 = (int *) &KQ[j*Dk]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + Dk/sizeof(int)); #pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/sizeof(int); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; @@ -153,7 +155,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); @@ -165,8 +167,8 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; - const int k_start = parallel_blocks == 1 ? 0 : ip*D; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { // Calculate KQ tile and keep track of new maximum KQ values: float kqmax_new_arr[ncols]; @@ -176,10 +178,10 @@ static __global__ void flash_attn_vec_ext_f32( } #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + if ((i_KQ_0 + nwarps > Dk && i_KQ >= Dk) || (FATTN_KQ_STRIDE % Dk != 0 && k_VKQ_0 + i_KQ >= ne11)) { break; } @@ -195,7 +197,7 @@ static __global__ void flash_attn_vec_ext_f32( kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; + KQ[j*Dk + i_KQ] = sum; } } } @@ -220,9 +222,9 @@ static __global__ void flash_attn_vec_ext_f32( const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); kqmax[j] = kqmax_new_j; - const float val = expf(KQ[j*D + tid] - kqmax[j]); + const float val = expf(KQ[j*Dk + tid] - kqmax[j]); kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; + KQ[j*Dk + tid] = val; VKQ[j] *= KQ_max_scale; } @@ -230,15 +232,15 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); #pragma unroll - for (int k = 0; k < D; ++k) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { + for (int k = 0; k < Dv; ++k) { + if (FATTN_KQ_STRIDE % Dv != 0 && k_VKQ_0 + k >= ne11) { break; } const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_ki*KQ[j*D + k]; + VKQ[j] += V_ki*KQ[j*Dk + k]; } } @@ -269,24 +271,25 @@ static __global__ void flash_attn_vec_ext_f32( dst_val /= kqsum[j_VKQ]; } const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; } if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); } + } } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr int nwarps = Dk/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + constexpr bool need_f16_K = Dk != 128 && Dk != 192; + constexpr bool need_f16_V = Dv != 128 && Dv != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -303,9 +306,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -314,9 +317,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -325,9 +328,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -336,9 +339,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } @@ -346,15 +349,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ template void ggml_cuda_flash_attn_ext_vec_f32_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f32_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); @@ -406,3 +413,6 @@ extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index efe78a2f..c5ffc7d1 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -5,8 +5,8 @@ #include #endif // FP16_MMA_AVAILABLE -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +// Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel: +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -47,8 +47,9 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)); // Skip unused kernel variants for faster compilation: - if (use_softcap && !(D == 128 || D == 256)) { + if (use_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } @@ -58,11 +59,11 @@ static __global__ void flash_attn_ext_f16( const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(Dk <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + static_assert(Dk % frag_m == 0, "If ncols == 8 then Dk % frag_m must be 0."); typedef nvcuda::wmma::fragment frag_a_K; typedef nvcuda::wmma::fragment frag_a_V; typedef nvcuda::wmma::fragment frag_b; @@ -74,30 +75,32 @@ static __global__ void flash_attn_ext_f16( static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; + constexpr int Dk_padded = Dk + 8; + constexpr int Dv_padded = Dv + 8; constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); + const int stride_Q = nb01 / sizeof(float); + const int stride_K = nb11 / sizeof(half); + const int stride_V = nb21 / sizeof(half); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); const half2 softcap_2 = make_half2(softcap, softcap); - frag_b Q_b[D/16][ncols/frag_n]; + frag_b Q_b[Dk/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded; __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; float * KQ_f = (float *) KQ; half2 * KQ2 = (half2 *) KQ; @@ -120,18 +123,18 @@ static __global__ void flash_attn_ext_f16( KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); } - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + __shared__ half VKQ[ncols*Dv_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) { break; } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + VKQ2[j*(Dv_padded/2) + i] = make_half2(0.0f, 0.0f); } } @@ -140,12 +143,12 @@ static __global__ void flash_attn_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dk; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + WARP_SIZE > Dk && i >= Dk) { break; } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + KQ[j*Dk_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; } } @@ -153,10 +156,10 @@ static __global__ void flash_attn_ext_f16( // Load Q into tensor core fragments/registers since it will be used frequently: #pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { + for (int i0 = 0; i0 < Dk; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*Dk_padded + i0, Dk_padded); } } @@ -173,9 +176,9 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); } #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 16) { frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -309,9 +312,9 @@ static __global__ void flash_attn_ext_f16( } } - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; + frag_c_VKQ VKQ_c[Dv/VKQ_stride][ncols/frag_n]; #pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { + for (int i_VKQ_0 = 0; i_VKQ_0 < Dv; i_VKQ_0 += VKQ_stride) { #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); @@ -322,7 +325,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -332,15 +335,15 @@ static __global__ void flash_attn_ext_f16( __syncthreads(); - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*Dk_padded); #pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { + for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += VKQ_stride) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ + offset_k + j0*Dk_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); + Dk_padded, nvcuda::wmma::mem_col_major); } } @@ -358,18 +361,18 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) { break; } half2 VKQ_add = make_half2(0.0f, 0.0f); #pragma unroll for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + VKQ_add += KQ2[l*(ncols*Dk_padded/2) + j*(Dk_padded/2) + i]; } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + VKQ2[j*(Dv_padded/2) + i] = VKQ_scale*VKQ2[j*(Dv_padded/2) + i] + VKQ_add; } } @@ -392,16 +395,16 @@ static __global__ void flash_attn_ext_f16( } #pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < Dv; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (i0 + WARP_SIZE > Dv && i >= Dv) { break; } - float dst_val = VKQ[j_VKQ*D_padded + i]; + float dst_val = VKQ[j_VKQ*Dv_padded + i]; if (parallel_blocks == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + dst[j_dst*gridDim.y*Dv + blockIdx.y*Dv + i] = dst_val; } if (parallel_blocks == 1 || threadIdx.x != 0) { @@ -446,13 +449,13 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); -template +template void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + constexpr int frag_m = cols_per_block == 8 && Dk % 32 == 0 ? 32 : 16; const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; @@ -462,29 +465,33 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16 : - flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16 : + flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16 : - flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16 : + flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; fattn_kernel_t fattn_kernel = softcap == 0.0f ? - flash_attn_ext_f16 : - flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + flash_attn_ext_f16 : + flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ template void ggml_cuda_flash_attn_ext_wmma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_WMMA_F16_CASE_DKDV(Dk, Dv, cols_per_block, KQ_acc_t) \ + template void ggml_cuda_flash_attn_ext_wmma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); @@ -518,3 +525,7 @@ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half); +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half); +extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c15d6c81..6329908d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,6 +12,14 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * V = dst->src[2]; + + if (Q->ne[0] != V->ne[0]) { + if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) { + fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]); + GGML_ABORT("fatal error"); + } + } const int32_t precision = KQV->op_params[3]; @@ -20,22 +28,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 16; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -46,19 +57,22 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); break; // case 256: // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); @@ -76,16 +90,19 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 8; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -99,22 +116,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 16; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -127,22 +147,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g constexpr int cols_per_block = 32; switch (Q->ne[0]) { case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); break; case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); break; case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); break; case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); break; case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); break; case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); break; default: fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); @@ -152,7 +175,13 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ + ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ + return; \ + } \ + +#define FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \ + if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ return; \ } \ @@ -218,6 +247,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q6_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,14 +263,24 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } #define FATTN_VEC_F32_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ - ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ + ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ + return; \ + } \ + +#define FATTN_VEC_F32_CASE_DKDV(Dk, Dv, type_K, type_V) \ + if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ return; \ } \ @@ -298,6 +340,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -306,9 +351,12 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #endif // GGML_CUDA_FA_ALL_QUANTS - on_no_fattn_vec_case(Q->ne[0]); + on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu new file mode 100644 index 00000000..7dda0133 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-f16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu new file mode 100644 index 00000000..740ac37d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs192-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu new file mode 100644 index 00000000..1ea24302 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-f16-f16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu new file mode 100644 index 00000000..6be4d042 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs192-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu index 2d94e65c..334e1deb 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, float); DECL_FATTN_WMMA_F16_CASE(112, 16, float); DECL_FATTN_WMMA_F16_CASE(128, 16, float); DECL_FATTN_WMMA_F16_CASE(256, 16, float); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu index c3d9df3c..1faf3c9b 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu @@ -7,3 +7,5 @@ DECL_FATTN_WMMA_F16_CASE(80, 32, float); DECL_FATTN_WMMA_F16_CASE(96, 32, float); DECL_FATTN_WMMA_F16_CASE(112, 32, float); DECL_FATTN_WMMA_F16_CASE(128, 32, float); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, float); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu index bb680e40..48973618 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, half); DECL_FATTN_WMMA_F16_CASE(112, 16, half); DECL_FATTN_WMMA_F16_CASE(128, 16, half); DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu index 073f71b1..ed92963e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu @@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 32, half); DECL_FATTN_WMMA_F16_CASE(112, 32, half); DECL_FATTN_WMMA_F16_CASE(128, 32, half); DECL_FATTN_WMMA_F16_CASE(256, 32, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu index d30710c5..4e221003 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu @@ -6,3 +6,5 @@ DECL_FATTN_WMMA_F16_CASE(64, 8, half); DECL_FATTN_WMMA_F16_CASE(96, 8, half); DECL_FATTN_WMMA_F16_CASE(128, 8, half); DECL_FATTN_WMMA_F16_CASE(256, 8, half); + +DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half); diff --git a/src/llama.cpp b/src/llama.cpp index 5ac44055..e246dec9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8777,7 +8777,11 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { + // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA + // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. + // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel. + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || + (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8)) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); From c67a37b251fc22b0f8b8313ea5c76a73ff6ed49f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 7 Mar 2025 08:54:09 +0200 Subject: [PATCH 031/132] Custom quantization rules with regular expressions (#244) * Custom quantization rules with regular expressions * Add the --custom-q option to the help --------- Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 31 +++++++++++++++++++++++++++++++ include/llama.h | 1 + src/llama.cpp | 19 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 916f57ec..89de794b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -144,6 +144,7 @@ static void usage(const char * executable) { printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor.\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token_embd.weight tensor.\n\n"); + printf(" --custom-q regex1=type1,regex2=type2...: use this to specify custom quantization type rules.\n\n"); printf("Additional specific tensor quantization types used in the custom quant scheme 'CQS (default is Q2_K):\n"); printf(" --attn-q-type ggml_type: use this ggml_type for the attn_q.weight tensor.\n"); printf(" --attn-k-type ggml_type: use this ggml_type for the attn_k.weight tensor.\n"); @@ -290,6 +291,28 @@ static ggml_type parse_ggml_type(const char * arg) { return result; } +using CustomQ = std::pair; + +static bool parse_custom_quants(const std::string& arg, std::vector& custom_quants) { + for (const auto & item : string_split(arg, ',')) { + auto pos = item.find('='); + if (pos == std::string::npos) { + fprintf(stderr, "Invalid custom quantization input %s\n", arg.c_str()); + return false; + } + auto pattern = item.substr(0, pos); + auto type_as_string = item.substr(pos + 1); + auto type = parse_ggml_type(type_as_string.c_str()); + if (type == GGML_TYPE_COUNT) { + fprintf(stderr, "Invalid quantization type '%s' in custom quantization input %s\n", type_as_string.c_str(), item.c_str()); + return false; + } + printf("Adding custom rule %s -> %s\n", pattern.c_str(), ggml_type_name(type)); + custom_quants.emplace_back(std::move(pattern), type); + } + return true; +} + int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -301,6 +324,7 @@ int main(int argc, char ** argv) { std::string imatrix_file; std::vector included_weights, excluded_weights; std::vector kv_overrides; + std::vector custom_quants; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { @@ -371,6 +395,10 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--custom-q") == 0) { + if (arg_idx == argc-1 || !parse_custom_quants(argv[++arg_idx], custom_quants)) { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) { params.allow_requantize = true; } else if (strcmp(argv[arg_idx], "--pure") == 0) { @@ -451,6 +479,9 @@ int main(int argc, char ** argv) { kv_overrides.back().key[0] = 0; params.kv_overrides = &kv_overrides; } + if (!custom_quants.empty()) { + params.custom_quants = &custom_quants; + } llama_backend_init(); diff --git a/include/llama.h b/include/llama.h index 38a12744..5e86cb68 100644 --- a/include/llama.h +++ b/include/llama.h @@ -418,6 +418,7 @@ extern "C" { bool ignore_imatrix_rules; // If set to true, the built-in rules for refusing to quantize into certain quants without imatrix are ignored void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides + void * custom_quants; // pointer to vector containing custom quantization rules } llama_model_quantize_params; // grammar types diff --git a/src/llama.cpp b/src/llama.cpp index e246dec9..9c9739e9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16283,6 +16283,19 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; }; + auto custom_type = GGML_TYPE_COUNT; + if (qs.params->custom_quants) { + using CustomQ = std::pair; + auto& q_rules = *static_cast*>(qs.params->custom_quants); + for (auto& rule : q_rules) { + std::regex pattern(rule.first); + if (std::regex_search(name, pattern)) { + custom_type = rule.second; + break; + } + } + } + //auto get_layer = [] (const char * name) { // int il; // if (sscanf(name, "blk.%d.", &il) == 1) return il; @@ -16752,6 +16765,11 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ++qs.i_ffn_up; } + if (custom_type < GGML_TYPE_COUNT) { + new_type = custom_type; + LLAMA_LOG_INFO("Using custom type %s for tensor %s\n", ggml_type_name(new_type), name.c_str()); + } + // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; //} // IK: let's remove this, else Q2_K is almost the same as Q3_K_S @@ -17791,6 +17809,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.ignore_imatrix_rules =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.custom_quants =*/ nullptr, }; return result; From 3d85a1d66302989401f92a5ae347577b03cbdaa7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 7 Mar 2025 09:46:58 +0200 Subject: [PATCH 032/132] Better FlashMLA (#243) * This is a better FA for TG It should benefit MLA and GQA. Tested to work with DeepSeek-Lite MLA, not yet for GQA. For tg64@pp8192 it is ~13% faster than MLA without FA, and 57% faster that the main branch FA. * WIP * Cleanup --------- Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 4 +- ggml/src/ggml.c | 112 ++++++++----- ggml/src/iqk/iqk_common.h | 138 ++++++++++++++++ ggml/src/iqk/iqk_flash_attn.cpp | 194 ++++++++++++++++++++++ ggml/src/iqk/iqk_flash_impl.h | 23 +++ ggml/src/iqk/iqk_mul_mat.cpp | 274 ++++++++++++++++---------------- ggml/src/iqk/iqk_mul_mat.h | 16 +- 7 files changed, 582 insertions(+), 179 deletions(-) create mode 100644 ggml/src/iqk/iqk_common.h create mode 100644 ggml/src/iqk/iqk_flash_attn.cpp create mode 100644 ggml/src/iqk/iqk_flash_impl.h diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 0ed84956..c1ebd870 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -258,8 +258,8 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h) if (GGML_IQK_MUL_MAT) message(STATUS "Using optimized iqk matrix multiplications") add_compile_definitions(GGML_USE_IQK_MULMAT) - set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp) - set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h) + set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp) + set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h) if (GGML_IQK_FA_ALL_QUANTS) message(STATUS "Including all IQK FA kernels") add_compile_definitions(GGML_IQK_FA_ALL_QUANTS) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 46e1a548..e5ad15f2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -17870,46 +17870,57 @@ static void ggml_compute_forward_flash_attn_ext_f16( } #if GGML_USE_IQK_MULMAT - if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { - //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", - // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); - // I keep changing my mind what is the best strategy to split the threads when processing - // multiple heads. This is my current thinking, the commented out code below was the previous. - int ntg = nth/simple_gcd(neq2*neq3, nth); - int64_t neq1g = (neq1 + ntg - 1)/ntg; - //int64_t work_per_slice = D*nek1*neq1; - //int ntg = 1; - // - // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix - // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of - // the number of threads processing the (iq2, iq3) matrix. - // - //if (neq1 >= 8*nth) { - // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; - // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; - // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; - //} - int counter = 0; - for (int64_t iq3 = 0; iq3 < neq3; iq3++) { - for (int64_t iq2 = 0; iq2 < neq2; iq2++) { - if (counter++ % (nth/ntg) == ith/ntg) { - int iq1 = (ith%ntg)*neq1g; - int this_neq1 = MIN(neq1g, neq1-iq1); - if (!iqk_flash_attn_noalibi(k->type, v->type, - Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), - (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), - (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), - (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), - (const void *)((const char *)mask->data + iq1*mask->nb[1]), - scale, softcap, - (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; - } - } - } - return; -IQK_Flash_Attn_NotAvailable:; - printf("iqk_flash was rejected\n"); - } + if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias, + q->ne[3], q->ne[2], q->nb[3], q->nb[2], + k->ne[3], k->ne[2], k->nb[3], k->nb[2], + v->ne[3], v->ne[2], v->nb[3], v->nb[2], + dst->ne[2], dst->ne[1], dst->nb[1], + k->type, v->type, + Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], + q->data, k->data, v->data, mask->data, + scale, softcap, (float *)dst->data, + params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return; + +// if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { +// //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n", +// // k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]); +// // I keep changing my mind what is the best strategy to split the threads when processing +// // multiple heads. This is my current thinking, the commented out code below was the previous. +// int ntg = nth/simple_gcd(neq2*neq3, nth); +// int64_t neq1g = (neq1 + ntg - 1)/ntg; +// //int64_t work_per_slice = D*nek1*neq1; +// //int ntg = 1; +// // +// // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix +// // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of +// // the number of threads processing the (iq2, iq3) matrix. +// // +// //if (neq1 >= 8*nth) { +// // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; +// // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; +// // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; +// //} +// int counter = 0; +// for (int64_t iq3 = 0; iq3 < neq3; iq3++) { +// for (int64_t iq2 = 0; iq2 < neq2; iq2++) { +// if (counter++ % (nth/ntg) == ith/ntg) { +// int iq1 = (ith%ntg)*neq1g; +// int this_neq1 = MIN(neq1g, neq1-iq1); +// if (!iqk_flash_attn_noalibi(k->type, v->type, +// Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), +// (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), +// (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), +// (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), +// (const void *)((const char *)mask->data + iq1*mask->nb[1]), +// scale, softcap, +// (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; +// } +// } +// } +// return; +//IQK_Flash_Attn_NotAvailable:; +// printf("iqk_flash was rejected\n"); +// } #endif const uint32_t n_head = neq2; @@ -21534,6 +21545,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int64_t D = MAX(Dk, Dv); cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread +#if GGML_USE_IQK_MULMAT + const struct ggml_tensor * q = node->src[0]; + const struct ggml_tensor * k = node->src[1]; + if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { + int nstep_k = k->ne[1]/32; + int gcd_k = simple_gcd(nstep_k, n_tasks); + if (gcd_k > 1) { + int nth_k = n_tasks/gcd_k; + int rk2 = q->ne[2]/k->ne[2]; + if (rk2%nth_k == 0) { + size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; + } + cur = MAX(cur, size); + } + } + } +#endif } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h new file mode 100644 index 00000000..dc3e369f --- /dev/null +++ b/ggml/src/iqk/iqk_common.h @@ -0,0 +1,138 @@ +// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- +// vi: set et ft=cpp fenc=utf-8 :vi +// +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#include "iqk_config.h" + +#if defined IQK_IMPLEMENT + +#include +#include +#include + +#include "ggml-impl.h" +#include "ggml-quants.h" +#include "iqk_mul_mat.h" +#include "iqk_quantize.h" + +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#define FA_TIMING 0 + +#include +#include +#if FA_TIMING +#include +#include +struct Perf { + using TimePoint = std::chrono::time_point; + std::array times = {}; + std::mutex mutex; + bool report; + static auto cur_time() { return std::chrono::high_resolution_clock::now(); } + inline void accum(int what, const TimePoint& t1) { + auto t2 = cur_time(); + auto dt = delta(t1, t2); + std::lock_guard lock(mutex); + times[what] += dt; + } + inline void accum_nolock(int what, const TimePoint& t1) { + auto t2 = cur_time(); + auto dt = delta(t1, t2); + times[what] += dt; + } + inline void add(const Perf& other) { + std::lock_guard lock(mutex); + for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i]; + } + Perf(bool r) : report(r) {} + ~Perf() { + if (report) { + double tot = 0; + for (auto& t : times) tot += t; + if (!tot) return; + printf("======================= Timing: %g ms in total\n", tot); + for (int i = 0; i < int(times.size()); ++i) { + if (times[i]) { + printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%'); + } + } + } + } + static Perf& instance() { + static Perf p(true); + return p; + } + static double delta(const TimePoint& t1, const TimePoint& t2) { + return 1e-6*std::chrono::duration_cast(t2-t1).count(); + } +}; +#endif + +#ifdef __AVX2__ +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +#endif + +namespace { + +typedef struct { + int32_t i1; + int32_t i2; +} mmid_row_mapping; + +struct DataInfo { + float * s; + const char * cy; + size_t bs; + size_t by; + int cur_y = 0; + int ne11; + const mmid_row_mapping * row_mapping = nullptr; + size_t bs2 = 0; + + inline const char * src1_row(int iy) const { + if (!row_mapping) return cy + (cur_y + iy)*by; + int i11 = row_mapping[cur_y + iy].i1 % ne11; + int i12 = row_mapping[cur_y + iy].i2; + return cy + (i11 + i12*ne11)*by; + } + + inline void store(int ix, int iy, float result) const { + *(dst_row(iy) + ix) = result; + } +#ifdef __AVX__ + inline void store(int ix, int iy, __m128 result) const { + _mm_storeu_ps(dst_row(iy) + ix, result); + } + inline void store(int ix, int iy, __m256 result) const { + _mm256_storeu_ps(dst_row(iy) + ix, result); + } +#endif +#ifdef __AVX512F__ + inline void store(int ix, int iy, __m512 result) const { + _mm512_storeu_ps(dst_row(iy) + ix, result); + } +#endif +#ifdef __ARM_NEON + inline void store(int ix, int iy, float32x4_t result) const { + vst1q_f32(dst_row(iy) + ix, result); + } +#endif + inline float * dst_row(int iy) const { + if (!row_mapping) return s + (cur_y + iy)*bs; + int i12 = row_mapping[cur_y + iy].i2; + int i1 = row_mapping[cur_y + iy].i1; + int i2 = i12; + return s + i1*bs + i2*bs2; + } +}; + +typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); + +#endif diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp new file mode 100644 index 00000000..fecd818b --- /dev/null +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -0,0 +1,194 @@ +#include "iqk_config.h" +#include "iqk_mul_mat.h" +#include "iqk_flash_impl.h" + +#ifdef IQK_IMPLEMENT + +#include +#include +#include +#include +#include +#include + +namespace { +inline uint32_t simple_gcd(uint32_t a, uint32_t b) { + while (a != b) { + if (a > b) a -= b; + else b -= a; + } + return a; +} +} + +// TODO: get the ggml_type enum here without polution +// +bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, + int neq3, int neq2, long nbq3, long nbq2, + int nek3, int nek2, long nbk3, long nbk2, + int nev3, int nev2, long nbv3, long nbv2, + int ne2, int ne1, long nb1, + int int_type_k, // type of k + int int_type_v, // type of v + int Dk, // K head size + int Dv, // V head size + int neq1, // number of columns in q + int nek1, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + const void * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv, // v*softmax(scale*(k*q)) + [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + int ith, int nth) { + + if (type_q != 0 || type_mask != 1 || max_bias > 0) return false; + + int rk2 = neq2/nek2; + int rv2 = neq2/nev2; + int rk3 = neq3/nek3; + int rv3 = neq3/nev3; + + // Getting confused all the time about where to load data from and store the results to + // (especially when combining the results from the threads). + // So, for now, making it work just for MLA (nek2 = 1). + // I think it would also speed up things for GQA, but I'm leaving this for another day. + if (neq3 == 1 && rk2 > 1 && neq1 == 1 && nth >= 1 && nek1/32 > 1 && nek2 == 1) { + int nstep_k = nek1/32; + int gcd_k = simple_gcd(nstep_k, nth); + if (gcd_k >= 1) { + int nth_k = nth/gcd_k; + if (rk2%nth_k == 0) { + int ith_k = ith%gcd_k; + int ith_q = ith/gcd_k; + auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k; + auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v; + auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2; + auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here + auto work = (char *)work_buffer; + + // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float) + // In addition, we need M, S for the rk2/nth_k rows the thread is processing + // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not + // writing onto the same cache line. + auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float); + auto result_buffer = work; + auto work_this_thread = (float *)(result_buffer + ith*size_thread); + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, + (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, + scale, softcap, + work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false; + + barrier(barrier_data); + + // TODO: simdify this + for (int j = ith; j < rk2; j += nth) { + auto Racc = qkv + j*nb1/sizeof(float); + float M = -INFINITY, S = 0; + int jth_q = j/(rk2/nth_k); + int jj = j%(rk2/nth_k); + for (int j1 = 0; j1 < rk2/nth_k; ++j1) { + auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread); + auto Mj = R + Dv*rk2/nth_k; + auto Sj = Mj + rk2/nth_k; + R += jj*Dv; + if (Mj[jj] == -INFINITY) continue; + if (Mj[jj] > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj[jj]; + } else { + float c = exp(M - Mj[jj]); + S = c*S + Sj[jj]; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj[jj]; + } else { + float c = exp(Mj[jj] - M); + S += c*Sj[jj]; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + + } + } + } + + // I keep changing my mind what is the best strategy to split the threads when processing + // multiple heads. This is my current thinking, the commented out code below was the previous. + int ntg = nth/simple_gcd(neq2*neq3, nth); + int neq1g = (neq1 + ntg - 1)/ntg; + //int64_t work_per_slice = D*nek1*neq1; + //int ntg = 1; + // + // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix + // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of + // the number of threads processing the (iq2, iq3) matrix. + // + //if (neq1 >= 8*nth) { + // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; + // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; + // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; + //} + int counter = 0; + for (int64_t iq3 = 0; iq3 < neq3; iq3++) { + for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + if (counter++ % (nth/ntg) == ith/ntg) { + int iq1 = (ith%ntg)*neq1g; + int this_neq1 = std::min(neq1g, neq1-iq1); + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, this_neq1, nek1, stride_q, stride_k, stride_v, stride_m, ne1*nb1/sizeof(float), + (const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q), + (const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3), + (const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3), + (const void *)((const char *)mask + iq1*stride_m), + scale, softcap, + (float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false; + } + } + } + + return true; +} + +#else + +bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int type_mask, [[maybe_unused]] float max_bias, + [[maybe_unused]] int neq3, [[maybe_unused]] int neq2, [[maybe_unused]] long nbq3, [[maybe_unused]] long nbq2, + [[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2, + [[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2, + [[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1, + [[maybe_unused]] int int_type_k, // type of k + [[maybe_unused]] int int_type_v, // type of v + [[maybe_unused]] int D, // head size + [[maybe_unused]] int nq, // number of columns in q + [[maybe_unused]] int nk, // number of rows in k + [[maybe_unused]] int stride_q, // distance between q columns in bytes + [[maybe_unused]] int stride_k, // distance between k rows in bytes + [[maybe_unused]] int stride_v, // distance between v rows in bytes + [[maybe_unused]] int stride_m, // distance between mask rows (in bytes + [[maybe_unused]] const void * q, // q matrix. + [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements + [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + [[maybe_unused]] float scale, // scale applied before softmax + [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax + [[maybe_unused]] float * qkv, // v*softmax(scale*(k*q)) + [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + [[maybe_unused]] int ith, [[maybe_unused]] int nth) { + return false; +} + +#endif + diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h new file mode 100644 index 00000000..bc68e0d8 --- /dev/null +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -0,0 +1,23 @@ +#pragma once + +bool iqk_flash_attn_impl(int type_k, // type of k + int type_v, // type of v + int Dk, // K head size + int Dv, // V head size + int nq, // number of columns in q + int nk, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv, // v*softmax(scale*(k*q)) + float * M, + float * S); + diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1f18837c..14cc64db 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15870,23 +15870,23 @@ struct FlashMS { inline void update_M_S(int j, float32x4_t * vk) { float smax = load_and_scale(j, vk); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } inline void update_M_S(int j, float32x4_t * vk, const char * mask) { float smax = load_apply_mask_and_scale(j, vk, mask); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } #else inline void update_M_S(int j, F16::Data * vk) { float smax = load_and_scale(j, vk); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } inline void update_M_S(int j, F16::Data * vk, const char * mask) { float smax = load_apply_mask_and_scale(j, vk, mask); update_M(j, smax); - update_S(j, vk); + if (M[j] > -INFINITY) update_S(j, vk); } #endif @@ -16037,7 +16037,7 @@ struct FlashQKV { } } - inline void normalize_and_store(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { + inline void normalize_and_store_1row(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { GGML_ASSERT(fms.S[j] > 0); auto norm = F16::set1(1/fms.S[j]); //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); @@ -16047,21 +16047,43 @@ struct FlashQKV { } } - inline void normalize_and_store(const FlashMS& fms, int nq1, int stride_qkv, float * qkv) const { - auto R = qkv_cache; - for (int j = 0; j < nq1; ++j) { - normalize_and_store(fms, j, R, qkv); - qkv += stride_qkv; - R += D; + inline void normalize_and_store(const FlashMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const { + if (M && S) { + std::memcpy(M, fms.M, nq1*sizeof(float)); + std::memcpy(S, fms.S, nq1*sizeof(float)); + auto R = qkv_cache; + for (int j = 0; j < nq1; ++j) { + std::memcpy(qkv, R, D*sizeof(float)); + qkv += stride_qkv; + R += D; + } + } else { + auto R = qkv_cache; + for (int j = 0; j < nq1; ++j) { + normalize_and_store_1row(fms, j, R, qkv); + qkv += stride_qkv; + R += D; + } } } - inline void normalize_and_store(const FlashMS& fms, int stride_qkv, float * qkv) const { - auto R = qkv_cache; - for (int j = 0; j < q_step; ++j) { - normalize_and_store(fms, j, R, qkv); - qkv += stride_qkv; - R += D; + inline void normalize_and_store(const FlashMS& fms, int stride_qkv, float * qkv, float * M, float * S) const { + if (M && S) { + std::memcpy(M, fms.M, q_step*sizeof(float)); + std::memcpy(S, fms.S, q_step*sizeof(float)); + auto R = qkv_cache; + for (int j = 0; j < q_step; ++j) { + std::memcpy(qkv, R, D*sizeof(float)); + qkv += stride_qkv; + R += D; + } + } else { + auto R = qkv_cache; + for (int j = 0; j < q_step; ++j) { + normalize_and_store_1row(fms, j, R, qkv); + qkv += stride_qkv; + R += D; + } } } @@ -16435,7 +16457,8 @@ template & fms, FlashQKV& fqkv, - const float * q, const char * mask, float * qkv) { + const float * q, const char * mask, float * qkv, + float * M, float * S) { #ifdef __aarch64__ float16_t q_f16[Dk*q_step]; #endif @@ -16459,11 +16482,12 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); q += q_step*stride_q; mask += q_step*stride_m; qkv += q_step*stride_qkv; + if (M && S) { M += q_step; S += q_step; } } int n_left = nq1 - q_step*(nq1/q_step); if (n_left > 0) { @@ -16485,7 +16509,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); } } @@ -16493,7 +16517,8 @@ template & fms, FlashQKV& fqkv, - const float * q, const char * mask, float * qkv) { + const float * q, const char * mask, float * qkv, + float * M, float * S) { typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; #if FA_TIMING Perf perf(false); @@ -16528,15 +16553,16 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, } #if FA_TIMING t1 = Perf::cur_time(); - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); perf.accum_nolock(3, t1); #else - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); #endif q += q_step*stride_q; mask += q_step*stride_m; qkv += q_step*stride_qkv; + if (M && S) { M += q_step; S += q_step; } } int n_left = nq1 - q_step*(nq1/q_step); if (n_left > 0) { @@ -16552,7 +16578,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); } #if FA_TIMING Perf::instance().add(perf); @@ -16580,12 +16606,12 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float * qkv) { + const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { if constexpr (std::is_same_v> || std::is_same_v> || std::is_same_v> || std::is_same_v>) { compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } else if constexpr (std::is_same_v>) { if (nq1 >= 8) { @@ -16597,10 +16623,10 @@ struct FlashAttn { HelperQ80R8 khr4(nk1, kh); #endif compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } else{ compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } } else if constexpr (std::is_same_v>) { @@ -16613,14 +16639,14 @@ struct FlashAttn { HelperQ8KVR8 khr4(nk1, kh); #endif compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } else{ compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } } else { compute_helper>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } } @@ -16987,7 +17013,7 @@ struct FlashAttnBF16 { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float * qkv) { + const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { ggml_bf16_t q_bf16[q_step*Dk]; #if FA_TIMING Perf perf(false); @@ -17023,7 +17049,7 @@ struct FlashAttnBF16 { #if FA_TIMING t1 = Perf::cur_time(); #endif - fqkv.normalize_and_store(fms, stride_qkv, qkv); + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); #if FA_TIMING perf.accum_nolock(4, t1); #endif @@ -17046,7 +17072,7 @@ struct FlashAttnBF16 { vh.next_block(); mr += k_step*sizeof(ggml_half); } - fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); + fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S); } #if FA_TIMING Perf::instance().add(perf); @@ -17060,32 +17086,32 @@ struct FlashAttnBF16 { template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv) { + const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { if (nk1 >= 256) { //4096) { if (nq1 >= 64) { FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } if (nq1 >= 32) { FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } if (nq1 >= 16) { FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } } if (nq1 >= 8) { FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } } @@ -17093,27 +17119,27 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv) { + float scale, float softcap, float * qkv, float * M, float * S) { HelperBF16 kh(k, stride_k); HelperBF16 vh(v, stride_v); if (nk1 >= 4096) { if (nq1 >= 64) { FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } else if (nq1 >= 16) { FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); return; } } if (nq1 >= 8) { FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { FlashAttnBF16 fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } } #endif @@ -17122,43 +17148,43 @@ template inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, - float scale, float softcap, float * qkv) { + float scale, float softcap, float * qkv, float * M, float * S) { switch (type_v) { case GGML_TYPE_F16: { HelperF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #ifdef __AVX512BF16__ case GGML_TYPE_BF16: { HelperBF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #endif case GGML_TYPE_Q8_0: { HelperQ80 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_KV: { HelperQ8KV vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q6_0: { HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; #endif default: break; @@ -17169,37 +17195,37 @@ template inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv) { + float scale, float softcap, float * qkv, float * M, float * S) { switch (type_k) { case GGML_TYPE_F16: { HelperF16 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_0: { HelperQ80 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q8_KV: { HelperQ8KV kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q6_0: { HelperQ60 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; #if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; #endif default: break; @@ -17223,13 +17249,13 @@ inline bool flash_attn_is_supported(ggml_type type) { template inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, - const float * q, const char * mask, float scale, float softcap, float * qkv) { + const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { if (nq1 % 8 == 0) { FlashAttn<576, 512, 8, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } else { FlashAttn<576, 512, 1, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } } @@ -17237,29 +17263,29 @@ template inline bool iqk_deepseek_helper(ggml_type type_k, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, - float scale, float softcap, float * qkv) { + float scale, float softcap, float * qkv, float * M, float * S) { if (type_k == GGML_TYPE_Q8_0) { HelperQ80<576, step_k> kh((const char *)k, stride_k); HelperQ80<512, step_k> vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } if (type_k == GGML_TYPE_Q6_0) { HelperQ60<576, step_k> kh((const char *)k, stride_k); HelperQ60<512, step_k> vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } if (type_k == GGML_TYPE_Q8_KV) { HelperQ8KV<576, step_k> kh((const char *)k, stride_k); HelperQ8KV<512, step_k> vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } if (type_k == GGML_TYPE_F16) { HelperF16<576, step_k> kh((const char *)k, stride_k); HelperF16<512, step_k> vh((const char *)v, stride_v); - iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } #ifdef __AVX512BF16__ @@ -17268,10 +17294,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k, HelperBF16<512, step_k> vh((const char *)v, stride_v); if (nq1 % 8 == 0) { FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } else { FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } return true; } @@ -17281,24 +17307,27 @@ inline bool iqk_deepseek_helper(ggml_type type_k, } -bool iqk_flash_attn_noalibi(int int_type_k, // type of k - int int_type_v, // type of v - int Dk, // K head size - int Dv, // V head size - int nq1, // number of columns in q - int nk1, // number of rows in k - int stride_q, // distance between q columns in bytes - int stride_k, // distance between k rows in bytes - int stride_v, // distance between v rows in bytes - int stride_m, // distance between mask rows (in bytes - int stride_qkv, // distance between rows in mask (in bytes) - const float * q, // q matrix. - const void * k, // k matrix. Assumed to be fp16, nq x nk elements - const void * v, // v matrix. Assumed to be fp16, nq x nk elements - const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements - float scale, // scale applied before softmax - float softcap, // if > 0, a "soft-cap" operation is applied before softmax - float * qkv) { // v*softmax(scale*(k*q)) +#include "iqk_flash_impl.h" + +bool iqk_flash_attn_impl(int int_type_k, // type of k + int int_type_v, // type of v + int Dk, // K head size + int Dv, // V head size + int nq1, // number of columns in q + int nk1, // number of rows in k + int stride_q, // distance between q columns in bytes + int stride_k, // distance between k rows in bytes + int stride_v, // distance between v rows in bytes + int stride_m, // distance between mask rows (in bytes + int stride_qkv, // distance between rows in mask (in bytes) + const float * q, // q matrix. + const void * k, // k matrix. Assumed to be fp16, nq x nk elements + const void * v, // v matrix. Assumed to be fp16, nq x nk elements + const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements + float scale, // scale applied before softmax + float softcap, // if > 0, a "soft-cap" operation is applied before softmax + float * qkv, // v*softmax(scale*(k*q)) + float * M, float * S) { if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32 @@ -17309,13 +17338,13 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k GGML_ASSERT(type_k == type_v); stride_q /= sizeof(float); // q stride as float return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, - q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv); + q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S); } if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false; if (Dk != Dv && Dk != 192 && Dv != 128) return false; if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false; - if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false; + if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dk != 256) return false; auto ck = (const char *)k; auto cv = (const char *)v; @@ -17329,15 +17358,15 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 96: - iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 128: - iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 192: - iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<192, 128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } @@ -17346,15 +17375,15 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 96: - iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 128: - iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 192: - iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<192, 128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } @@ -17366,21 +17395,21 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k if (nk1%64 == 0) { switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 192: - iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } @@ -17388,21 +17417,21 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k } switch (Dk) { case 64: - iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 192: - iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; case 256: - iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; default: return false; } @@ -17437,25 +17466,4 @@ bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/ return false; } - -bool iqk_flash_attn_noalibi([[maybe_unused]] int int_type_k, // type of k - [[maybe_unused]] int int_type_v, // type of v - [[maybe_unused]] int D, // head size - [[maybe_unused]] int nq, // number of columns in q - [[maybe_unused]] int nk, // number of rows in k - [[maybe_unused]] int stride_q, // distance between q columns in bytes - [[maybe_unused]] int stride_k, // distance between k rows in bytes - [[maybe_unused]] int stride_v, // distance between v rows in bytes - [[maybe_unused]] int stride_m, // distance between mask rows (in bytes - [[maybe_unused]] int stride_qkv, // distance between rows in mask (in bytes) - [[maybe_unused]] const float * q, // q matrix. - [[maybe_unused]] const void * k, // k matrix. Assumed to be fp16, nq x nk elements - [[maybe_unused]] const void * v, // v matrix. Assumed to be fp16, nq x nk elements - [[maybe_unused]] const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements - [[maybe_unused]] float scale, // scale applied before softmax - [[maybe_unused]] float softcap, // if > 0, a "soft-cap" operation is applied before softmax - [[maybe_unused]] float * qkv) { // v*softmax(scale*(k*q)) - return false; -} - #endif diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 767f89cf..d91c4710 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -33,7 +33,14 @@ bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); -bool iqk_flash_attn_noalibi(int type_k, // type of k +typedef void (*barrier_t) (void *); + +bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, + int neq3, int neq2, long nbq3, long nbq2, + int nek3, int nek2, long nbk3, long nbk2, + int nev3, int nev2, long nbv3, long nbv2, + int ne2, int ne1, long nb1, + int type_k, // type of k int type_v, // type of v int Dk, // K head size int Dv, // V head size @@ -43,14 +50,15 @@ bool iqk_flash_attn_noalibi(int type_k, // type of k int stride_k, // distance between k rows in bytes int stride_v, // distance between v rows in bytes int stride_m, // distance between mask rows (in bytes - int stride_qkv, // distance between rows in mask (in bytes) - const float * q, // q matrix. + const void * q, // q matrix. const void * k, // k matrix. Assumed to be fp16, nq x nk elements const void * v, // v matrix. Assumed to be fp16, nq x nk elements const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax - float * qkv); // v*softmax(scale*(k*q)) + float * qkv, // v*softmax(scale*(k*q)) + void * work_buffer, barrier_t barrier, void * barrier_data, + int ith, int nth); #ifdef __cplusplus } From 81748fb55e474ef1ddb3c64c14f7c378f0f6cd8b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 8 Mar 2025 19:33:41 +0200 Subject: [PATCH 033/132] Faster FlashMLA prompt processing (#246) * FlashMLA-2: faster prompt processing The current MLA implementation computes wv_b * (k_cache * softmax(k_cache * (wk_b*q))) This leads to 3.4X more multiply-adds (madds) compared to standard attention. Due to the resulting tensor shapes, TG is still faster than standard attention because the k_cache*(wk_b*q) and k_cache*(softmax(k_cache * (wk_b*q))) multiplications become GEMMs, so the additional madds are more than compensated for due to the much higher performance of GEMMs compared to GEMVs. But for PP, where we are dealing with GEMMs in both cases, the additional madds needed for MLA lead to lower performance, with the performance gap increasing with context length. So, then, when we are dealing with PP, we can rearrange the above to (wv_b * k_cache) * softmax( (wk_b^T*k_cache) * q), thus transforming it into the standard attention mechanism. We do need two additional matrix multiplications (which in practice is done as a single wkv_b * k_cache GEMM) with the *entire* K cache. But this is still cheaper than MLA, as we end up with 1.8X the madds required by standard attention. Oh, these figures are for the DeepSeek-V3/R1/Lite attention architecture. This leads to a significant PP performance increase compared to standard MLA with FA. There are many upsides to this: * If we only apply the above trick when we are processing more than X tokens (with suitable chosen X), TG performance stays the same as MLA with FA * We still need to store just the K-cache, so 576 entries per layer for DeepSeek-V3/R1/Lite * We get significantly better PP performance * We can use MLA+FA on CUDA. It works already with this commit for PP, something is not yet quite right for TG. The downside is that it only works with fp16 cache (for now). This is so because we need to convert the cache to fp32, else we cannot do the wkv_b * k_cache matrix multiplication (which in ggml requires the second operand to be fp32). But converting (copying) to fp32 only works for f16, bf16 and f32 tensors, so no luck with quantized cache. Another reason that we need to convert to fp32 is that the cache contains the RoPE'd portion, which we need to concatenate to the result of the wkv_b * k_cache matrix multiplication. Also this op works only when the tensors being concatenated are both fp32. So much about ggml being a general purpose ML library. * FlashMLA-2: on the CPU it now works for quantized cache except for q8_KV (q8_KV has row meta data, and there is still some confusion with row sizes because of that). * FlashMLA-2: on the CPU it now works also with q8_KV --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 41 ++++++++++++++ src/llama.cpp | 138 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 176 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e5ad15f2..4089e9b7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10565,6 +10565,42 @@ static void ggml_compute_forward_dup_bytes( } } +static void ggml_compute_forward_dup_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + struct ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0->ne[0] == dst->ne[0] && src0->nb[0] == ggml_type_size(src0->type)); + + ggml_to_float_t to_float = type_traits[src0->type].to_float; + GGML_ASSERT(to_float != NULL); + + int64_t nrows = ggml_nrows(dst); + int ith = params->ith; + int nth = params->nth; + + int64_t n_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*n_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + n_per_thread, nrows); + + for (int64_t ir = first_row; ir < last_row; ++ir) { + int64_t i03 = ir/(src0->ne[1]*src0->ne[2]); + int64_t i02 = (ir - i03*src0->ne[1]*src0->ne[2])/src0->ne[1]; + int64_t i01 = ir - i03*src0->ne[1]*src0->ne[2] - i02*src0->ne[1]; + int64_t i3 = ir/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (ir - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = ir - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + + const char * q = (const char *)src0->data + i03*src0->nb[3] + i02*src0->nb[2] + i01*src0->nb[1]; + char * f = ( char *)dst->data + i3* dst->nb[3] + i2* dst->nb[2] + i1* dst->nb[1]; + + to_float((const void *)q, (float *)f, src0->ne[0]); + } + +} + static void ggml_compute_forward_dup( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -10576,6 +10612,11 @@ static void ggml_compute_forward_dup( return; } + if (ggml_is_quantized(src0->type)) { + ggml_compute_forward_dup_q(params, dst); + return; + } + switch (src0->type) { case GGML_TYPE_F16: { diff --git a/src/llama.cpp b/src/llama.cpp index 9c9739e9..b45d9e70 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2691,6 +2691,9 @@ struct llama_kv_cache { // DeepSeek MLA std::vector kv_l; std::vector kvt_l; + ggml_tensor * kv_aux_f32 = nullptr; + ggml_tensor * k_aux = nullptr; + ggml_tensor * v_aux = nullptr; std::vector ctxs; std::vector bufs; @@ -3199,12 +3202,35 @@ static bool llama_kv_cache_init( if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); + if (cparams.mla_attn > 1 && cache.kv_aux_f32 == nullptr) { + + cache.kv_aux_f32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, + kv_lora_rank + n_embd_head_qk_rope, kv_size); + //(n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head, kv_size); + ggml_format_name(cache.kv_aux_f32, "kv_aux_f32%d", 0); + + cache.k_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_k, n_head, kv_size); + ggml_format_name(cache.k_aux, "k_aux%d", 0); + + cache.v_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_v, n_head, kv_size); + ggml_format_name(cache.v_aux, "v_aux%d", 0); + + //cache.kv_aux_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, + // (hparams.n_embd_head_k + hparams.n_embd_head_v)*n_head, kv_size); + //ggml_format_name(cache.kv_aux, "kv_aux%d", 0); + //ggml_format_name(cache.kv_aux_2, "kv_aux%d", 2); + LLAMA_LOG_INFO("%s: allocated kv auxilary tensors as %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, + cache.kv_aux_f32->ne[0], cache.kv_aux_f32->ne[1], + cache.k_aux->ne[0], cache.k_aux->ne[1], cache.k_aux->ne[2], + cache.v_aux->ne[0], cache.v_aux->ne[1], cache.v_aux->ne[2]); + } } else { auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); @@ -13625,6 +13651,111 @@ struct llm_build_context { ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); cb(kv_cache, "kv_cache", il); + ggml_tensor * kqv; + + if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { + + // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not + // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix + // multiplication, which *must* be f32. + auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); + auto kv_cache_view_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_self.kv_aux_f32->ne[0], n_kv, kv_self.kv_aux_f32->nb[1], 0); + kv_cache_view_f32 = ggml_cpy(ctx0, kv_cache_view, kv_cache_view_f32); + + // The no- and rorational position encoding portions of the KV cache + auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); + auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, + kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); + + auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); + + //// split into {n_head * n_embd_head_qk_nope, n_tokens} + //struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // 0); + //cb(k_nope, "k_nope", il); + + //// and {n_head * n_embd_head_v, n_tokens} + //struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + // ggml_row_size(kv->type, (n_embd_head_qk_nope))); + //cb(v_states, "v_states", il); + + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; + auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); + + auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); + + auto k_row_size = ggml_row_size(kv_self.k_aux->type, k_f32->ne[0]); + auto k = ggml_view_3d(ctx0, kv_self.k_aux, k_f32->ne[0], k_f32->ne[1], k_f32->ne[2], + k_row_size, k_row_size*k_f32->ne[1], 0); + //kv_self.k_aux->nb[1], k_row_size, 0); + //k_row_size, kv_self.k_aux->nb[1], 0); + k = ggml_cpy(ctx0, k_f32, k); + + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + + auto v_row_size = ggml_row_size(kv_self.v_aux->type, v_f32->ne[0]); + auto v = ggml_view_3d(ctx0, kv_self.v_aux, v_f32->ne[0], v_f32->ne[1], v_f32->ne[2], + v_row_size, v_row_size*v_f32->ne[1], 0); + //kv_self.v_aux->nb[1], v_row_size, 0); + //v_row_size, kv_self.v_aux->nb[1], 0); + v = ggml_cpy(ctx0, v_f32, v); + + //auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_cache->nb[1], 0); + //auto kv_cache_rope = ggml_view_2d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, kv_cache->nb[1], + // ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + ////kv_cache_rope = ggml_permute(ctx0, kv_cache_rope, 0, 2, 1, 3); + + //auto kv_cache_nope_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_lora_rank, n_kv, kv_self.kv_aux_f32->nb[1], 0); + //kv_cache_nope_f32 = ggml_cpy(ctx0, kv_cache_nope, kv_cache_nope_f32); + //auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope_f32); + + ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, model.layers[il].wkv_b->ne[1], n_kv); + ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, kv_f32->ne[0], kv_f32->ne[1]); + //auto kv = ggml_view_2d(ctx0, kv_self.kv_aux, kv_self.kv_aux->ne[0], n_kv, kv_self.kv_aux->nb[1], 0); + //kv = ggml_cpy(ctx0, kv_f32, kv); + + //auto k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_kv, n_head, + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + + //ggml_tensor repeater; + //repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1; + //auto k = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0); + + //auto v = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_kv, n_head, + // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + // ggml_row_size(kv->type, n_embd_head_qk_nope)); + + auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + + kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + cb(kqv, "kqv", il); + + cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); + + //kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + //cb(kqv_compressed, "kqv_compressed_perm", il); + } + else { + + ggml_tensor * kqv_compressed; + + //printf("wkv_b: %ld x %ld x %ld kv_cache: %ld x %ld x %ld\n", model.layers[il].wkv_b->ne[0], model.layers[il].wkv_b->ne[1], model.layers[il].wkv_b->ne[2], kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2]); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); @@ -13636,12 +13767,12 @@ struct llm_build_context { struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); cb(q_nope2, "q_nope2", il); + //printf("q_nope2 (%ld x %ld x %ld) = wk_b (%ld x %ld x %ld) * q_nope (%ld x %ld x %ld)\n", q_nope2->ne[0], q_nope2->ne[1], q_nope2->ne[2], + // wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], q_nope->ne[0], q_nope->ne[1], q_nope->ne[2]); ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - ggml_tensor * kqv_compressed; - if (lctx.cparams.flash_attn) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, @@ -13729,7 +13860,7 @@ struct llm_build_context { ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); cb(wv_b, "wv_b", il); - struct ggml_tensor * kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); @@ -13737,6 +13868,7 @@ struct llm_build_context { cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); cb(cur, "kqv_2d", il); + } ggml_build_forward_expand(gf, cur); From b096a5de7a9bdf516bb20729d5d0a3b2a12cba2f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 9 Mar 2025 16:53:55 +0200 Subject: [PATCH 034/132] This works on CUDA, but (#247) PP speed is great, almost on par with standard FA. But TG speed is pathetic. The strangest thing is that the slowdown is not due to FA, but due to the ffn_gate_exps gemm, which somehow becomes very slow. WTF? As I'm unable the resolve the slow ffn_gate_exps GEMM mystery, for now TG goes via mla=2, PP is via FA. Also discovered the ggml_cast op, so we don't need the aux tensors that I had added to the KV cache. Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 283 +++++++++++++++++++------------------------------- 1 file changed, 104 insertions(+), 179 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index b45d9e70..7d665072 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2691,9 +2691,6 @@ struct llama_kv_cache { // DeepSeek MLA std::vector kv_l; std::vector kvt_l; - ggml_tensor * kv_aux_f32 = nullptr; - ggml_tensor * k_aux = nullptr; - ggml_tensor * v_aux = nullptr; std::vector ctxs; std::vector bufs; @@ -3209,28 +3206,6 @@ static bool llama_kv_cache_init( ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); cache.kv_l.push_back(kv); - if (cparams.mla_attn > 1 && cache.kv_aux_f32 == nullptr) { - - cache.kv_aux_f32 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, - kv_lora_rank + n_embd_head_qk_rope, kv_size); - //(n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head, kv_size); - ggml_format_name(cache.kv_aux_f32, "kv_aux_f32%d", 0); - - cache.k_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_k, n_head, kv_size); - ggml_format_name(cache.k_aux, "k_aux%d", 0); - - cache.v_aux = ggml_new_tensor_3d(ctx, cache.type_k, hparams.n_embd_head_v, n_head, kv_size); - ggml_format_name(cache.v_aux, "v_aux%d", 0); - - //cache.kv_aux_2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, - // (hparams.n_embd_head_k + hparams.n_embd_head_v)*n_head, kv_size); - //ggml_format_name(cache.kv_aux, "kv_aux%d", 0); - //ggml_format_name(cache.kv_aux_2, "kv_aux%d", 2); - LLAMA_LOG_INFO("%s: allocated kv auxilary tensors as %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, - cache.kv_aux_f32->ne[0], cache.kv_aux_f32->ne[1], - cache.k_aux->ne[0], cache.k_aux->ne[1], cache.k_aux->ne[2], - cache.v_aux->ne[0], cache.v_aux->ne[1], cache.v_aux->ne[2]); - } } else { auto kv_type = cparams.mla_attn == 1 ? cache.type_k : cache.type_v; ggml_tensor * kv = ggml_new_tensor_2d(ctx, kv_type, kv_lora_rank + n_embd_head_qk_rope, kv_size); @@ -13659,215 +13634,165 @@ struct llm_build_context { // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix // multiplication, which *must* be f32. auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); - auto kv_cache_view_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_self.kv_aux_f32->ne[0], n_kv, kv_self.kv_aux_f32->nb[1], 0); - kv_cache_view_f32 = ggml_cpy(ctx0, kv_cache_view, kv_cache_view_f32); + auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32); + cb(kv_cache_view_f32, "kv_cache_view_f32", il); - // The no- and rorational position encoding portions of the KV cache + // The no- and rotational position encoding portions of the KV cache auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); - - //// split into {n_head * n_embd_head_qk_nope, n_tokens} - //struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - // 0); - //cb(k_nope, "k_nope", il); - - //// and {n_head * n_embd_head_v, n_tokens} - //struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - // ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - // ggml_row_size(kv->type, (n_embd_head_qk_nope))); - //cb(v_states, "v_states", il); + cb(kv_f32, "kv_f32", il); auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + cb(k_nope_f32, "k_nope_f32", il); ggml_tensor repeater; repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); + cb(k_rope_f32, "k_rope_f32", il); auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); + cb(k_f32, "k_f32", il); - auto k_row_size = ggml_row_size(kv_self.k_aux->type, k_f32->ne[0]); - auto k = ggml_view_3d(ctx0, kv_self.k_aux, k_f32->ne[0], k_f32->ne[1], k_f32->ne[2], - k_row_size, k_row_size*k_f32->ne[1], 0); - //kv_self.k_aux->nb[1], k_row_size, 0); - //k_row_size, kv_self.k_aux->nb[1], 0); - k = ggml_cpy(ctx0, k_f32, k); + auto k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type); + cb(k, "k", il); auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + cb(v_f32, "v_f32", il); - auto v_row_size = ggml_row_size(kv_self.v_aux->type, v_f32->ne[0]); - auto v = ggml_view_3d(ctx0, kv_self.v_aux, v_f32->ne[0], v_f32->ne[1], v_f32->ne[2], - v_row_size, v_row_size*v_f32->ne[1], 0); - //kv_self.v_aux->nb[1], v_row_size, 0); - //v_row_size, kv_self.v_aux->nb[1], 0); - v = ggml_cpy(ctx0, v_f32, v); - - //auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_cache->nb[1], 0); - //auto kv_cache_rope = ggml_view_2d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, kv_cache->nb[1], - // ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); - ////kv_cache_rope = ggml_permute(ctx0, kv_cache_rope, 0, 2, 1, 3); - - //auto kv_cache_nope_f32 = ggml_view_2d(ctx0, kv_self.kv_aux_f32, kv_lora_rank, n_kv, kv_self.kv_aux_f32->nb[1], 0); - //kv_cache_nope_f32 = ggml_cpy(ctx0, kv_cache_nope, kv_cache_nope_f32); - //auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope_f32); - - ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, model.layers[il].wkv_b->ne[1], n_kv); - ////auto kv = ggml_new_tensor_2d(ctx0, kv_self.kv_l[il]->type, kv_f32->ne[0], kv_f32->ne[1]); - //auto kv = ggml_view_2d(ctx0, kv_self.kv_aux, kv_self.kv_aux->ne[0], n_kv, kv_self.kv_aux->nb[1], 0); - //kv = ggml_cpy(ctx0, kv_f32, kv); - - //auto k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_kv, n_head, - // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); - - //ggml_tensor repeater; - //repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1; - //auto k = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0); - - //auto v = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_kv, n_head, - // ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - // ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - // ggml_row_size(kv->type, n_embd_head_qk_nope)); + auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); + cb(v, "v", il); auto q = ggml_concat(ctx0, q_nope, q_rope, 0); q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_concat", il); + + ggml_build_forward_expand(gf, q); kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + if (q->ne[1] <= 8) { + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + } cb(kqv, "kqv", il); cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); - //kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - //cb(kqv_compressed, "kqv_compressed_perm", il); } else { - ggml_tensor * kqv_compressed; + ggml_tensor * kqv_compressed; - //printf("wkv_b: %ld x %ld x %ld kv_cache: %ld x %ld x %ld\n", model.layers[il].wkv_b->ne[0], model.layers[il].wkv_b->ne[1], model.layers[il].wkv_b->ne[2], kv_cache->ne[0], kv_cache->ne[1], kv_cache->ne[2]); + struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, + ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), + ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); + cb(wk_b, "wk_b", il); - struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, - ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), - ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); - cb(wk_b, "wk_b", il); + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); - //if (q_nope->ne[1] <= 32) q_nope = ggml_cont(ctx0, q_nope); - cb(q_nope, "q_nope_perm", il); + struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); + cb(q_nope2, "q_nope2", il); - struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope); - cb(q_nope2, "q_nope2", il); - //printf("q_nope2 (%ld x %ld x %ld) = wk_b (%ld x %ld x %ld) * q_nope (%ld x %ld x %ld)\n", q_nope2->ne[0], q_nope2->ne[1], q_nope2->ne[2], - // wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], q_nope->ne[0], q_nope->ne[1], q_nope->ne[2]); + ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); + cb(q, "q", il); - ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); - cb(q, "q", il); - - if (lctx.cparams.flash_attn) { - ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], - kv_lora_rank, n_kv, - ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_cache_lora, "kv_cache_lora", il); - - //ggml_tensor * v = ggml_cont(ctx0, kv_cache_lora); - //kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); - - kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); - cb(kqv_compressed, "kqv_compressed", il); - - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); - } - else { - if (lctx.cparams.mla_attn > 1) { + if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_cache, "kv_cache_lora", il); + cb(kv_cache_lora, "kv_cache_lora", il); - kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); - cb(kv_cache_trans, "kv_cache_trans", il); - } - - auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB - if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { - if (!pp_opt) { - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_perm", il); - } - - ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); - cb(kq, "kq", il); - - if (!pp_opt) { - kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); - cb(kq, "kq_perm", il); - } - - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - - if (!pp_opt) { - kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); - cb(kq, "kq_soft_max_ext_perm", il); - } - - kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + kqv_compressed = ggml_flash_attn_ext(ctx0, q, kv_cache, kv_cache_lora, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); cb(kqv_compressed, "kqv_compressed", il); - if (!pp_opt) { - kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); - cb(kqv_compressed, "kqv_compressed_perm", il); + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + else { + if (lctx.cparams.mla_attn > 1) { + ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], + kv_lora_rank, n_kv, + ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cache, "kv_cache_lora", il); + + kv_cache_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_cache_lora)); + cb(kv_cache_trans, "kv_cache_trans", il); } - } else { - - int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; - n_step = std::min(n_step, int(q->ne[2])); - int n_per_step = (q->ne[2] + n_step - 1)/n_step; - - //printf("kq size would be %ld MiB -> splitting kqv computation into %d steps\n", kq_size, n_step); - - for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { - int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; - ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); - ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); - kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); - ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); - if (i_head == 0) { - kqv_compressed = kqv_i; - } else { - kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + auto kq_size = kv_cache->ne[1]*q->ne[1]*q->ne[2]*sizeof(float)/(1024*1024); // K*Q in MiB + if (lctx.cparams.attn_max_batch <= 0 || lctx.cparams.attn_max_batch >= kq_size) { + if (!pp_opt) { + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_perm", il); } - ggml_build_forward_expand(gf, kqv_compressed); + + ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + cb(kq, "kq", il); + + if (!pp_opt) { + kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3)); + cb(kq, "kq_perm", il); + } + + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + + if (!pp_opt) { + kq = ggml_permute(ctx0, kq, 0, 2, 1, 3); + cb(kq, "kq_soft_max_ext_perm", il); + } + + kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq); + cb(kqv_compressed, "kqv_compressed", il); + + if (!pp_opt) { + kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3); + cb(kqv_compressed, "kqv_compressed_perm", il); + } + + } else { + + int n_step = (kq_size + lctx.cparams.attn_max_batch - 1)/lctx.cparams.attn_max_batch; + n_step = std::min(n_step, int(q->ne[2])); + int n_per_step = (q->ne[2] + n_step - 1)/n_step; + + for (int i_head = 0; i_head < q->ne[2]; i_head += n_per_step) { + int this_ne12 = i_head + n_per_step <= q->ne[2] ? n_per_step : q->ne[2] - i_head; + ggml_tensor * q_i = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i_head); + ggml_tensor * kq_i = ggml_mul_mat(ctx0, kv_cache, q_i); + kq_i = ggml_soft_max_ext(ctx0, kq_i, KQ_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_tensor * kqv_i = ggml_mul_mat(ctx0, kv_cache_trans, kq_i); + if (i_head == 0) { + kqv_compressed = kqv_i; + } else { + kqv_compressed = ggml_concat(ctx0, kqv_compressed, kqv_i, 2); + } + ggml_build_forward_expand(gf, kqv_compressed); + } + cb(kqv_compressed, "kqv_compressed", il); } - cb(kqv_compressed, "kqv_compressed", il); } - } - struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, - ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), - ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); - cb(wv_b, "wv_b", il); + struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, + ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), + ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); + cb(wv_b, "wv_b", il); - kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); - cb(kqv, "kqv", il); + kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); + cb(kqv, "kqv", il); - kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); - cb(kqv, "kqv_perm", il); + kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cb(kqv, "kqv_perm", il); - cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); - cb(cur, "kqv_2d", il); + cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0); + cb(cur, "kqv_2d", il); } ggml_build_forward_expand(gf, cur); From 699c9cb7f63dd8431bce91b86e10efb41255f6c1 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 10 Mar 2025 16:16:51 +0200 Subject: [PATCH 035/132] Faster MoE token generation on CUDA (#248) * This gives us ~20% TG speedup for DeepSeek on CUDA * Slightly better * Also do it for plain (not fused) mul_mat_id * Guard against numerical precision issues for MLA on CUDA --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 310 ++++++++++++++++++++++++++++---- ggml/src/ggml-cuda/iqk_mmvq.cu | 88 ++++----- ggml/src/ggml-cuda/iqk_mmvq.cuh | 40 ++--- ggml/src/ggml-cuda/mmvq.cu | 250 ++++++++++++++------------ ggml/src/ggml-cuda/mmvq.cuh | 6 + src/llama.cpp | 3 + 6 files changed, 488 insertions(+), 209 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 410c6406..f25dd725 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1765,6 +1765,93 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } +/* +static void ggml_cuda_op_gemv_id( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src0_ids, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + quantize_cuda_t quantize_src1) { + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_nrows(src1) == 1); + GGML_ASSERT(src0_ids->ne[1] == 1); + GGML_ASSERT(src0_ids->ne[0] <= dst->ne[2]); + GGML_ASSERT(dst->ne[1] == 1); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + + GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); + GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); + + ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + + int device_id = ctx.device; + GGML_ASSERT(src0_ctx->device == device_id); + GGML_ASSERT(src1_ctx->device == device_id); + GGML_ASSERT(dst_ctx->device == device_id); + + const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); + GGML_ASSERT(!split); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t nrows1 = 1; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne2 = dst->ne[2]; + + const int64_t nb2 = dst->nb[2]; + + // Why? + GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); + + const size_t src0_rs = ggml_row_size(src0->type, ne00); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + ggml_cuda_pool_alloc src0_dd_alloc; + ggml_cuda_pool_alloc src1_ddf_alloc; + ggml_cuda_pool_alloc src1_ddq_alloc; + ggml_cuda_pool_alloc dst_dd_alloc; + + char * src0_dd = nullptr; + float * src1_ddf = (float *)src1->data; + char * src1_ddq = nullptr; // q8_1 + float * dst_dd = (float *)dst->data; + + bool quantization_done = false; + + const bool src1_on_device = device_id == src1_ctx->device; + const bool dst_on_device = device_id == dst_ctx->device; + + ggml_cuda_set_device(device_id); + cudaStream_t stream = ctx.stream(device_id, 0); + + src0_dd = (char *) src0->data; + + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + src1_ddq = src1_ddq_alloc.alloc(ctx.pool(device_id), src_1_ddq_size); + quantize_src1(src1_ddf, src1_ddq, ne10, 1, 1, src1_padded_col_size, src0->type, stream); + } + + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, src1, src0_ids, dst, + (const char *)src0->data, (const float *)src1->data, src1_ddq, (float *)dst->data, + 0, ne01, 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + +} +*/ + static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); @@ -2090,6 +2177,52 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0->type) && + ggml_backend_buffer_is_cuda(src0->buffer) && + ggml_backend_buffer_is_cuda(src1->buffer) && + ggml_backend_buffer_is_cuda(dst->buffer) && + !ggml_backend_buffer_is_cuda_split(src0->buffer) && + src1->type == GGML_TYPE_F32) { + int device_id = ctx.device; + ggml_backend_cuda_buffer_context * src0_ctx = (ggml_backend_cuda_buffer_context *) src0->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + if (src0_ctx->device == device_id && + src1_ctx->device == device_id && + dst_ctx->device == device_id) { + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + // Fast TG path + const int64_t n_ids = ids->ne[0]; + auto stream = ctx.stream(device_id, 0); + + auto local_dst = *dst; + local_dst.ne[2] = n_ids; + local_dst.ne[1] = local_dst.ne[3] = 1; + local_dst.nb[2] = local_dst.nb[1]; + + auto local_src1 = *src1; + local_src1.nb[2] = local_src1.nb[3] = 0; + + const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc src1_quantized(ctx.pool()); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size); + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, + src0->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_src1.nb[1] = src_1_ddq_size; + + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0, &local_src1, ids, &local_dst, + (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data, + 0, src0->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + return; + } + } + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); @@ -2232,6 +2365,121 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor const ggml_tensor * src1 = dst->src[2]; const ggml_tensor * ids = dst->src[3]; + if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0_1->type) && + ggml_is_quantized(src0_2->type) && + ggml_backend_buffer_is_cuda(src0_1->buffer) && + ggml_backend_buffer_is_cuda(src0_2->buffer) && + ggml_backend_buffer_is_cuda(src1->buffer) && + ggml_backend_buffer_is_cuda(dst->buffer) && + !ggml_backend_buffer_is_cuda_split(src0_1->buffer) && + !ggml_backend_buffer_is_cuda_split(src0_2->buffer) && + src1->type == GGML_TYPE_F32) { + int device_id = ctx.device; + ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context; + ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context; + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + if (src0_1_ctx->device == device_id && + src0_2_ctx->device == device_id && + src1_ctx->device == device_id && + dst_ctx->device == device_id) { + // Fast TG path + const int64_t n_ids = ids->ne[0]; + auto stream = ctx.stream(device_id, 0); + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*dst->ne[0]*n_ids); + + auto local_dst = *dst; + local_dst.ne[2] = n_ids; + local_dst.ne[1] = local_dst.ne[3] = 1; + local_dst.nb[1] = local_dst.nb[2] = local_dst.nb[3] = local_dst.ne[0]*sizeof(float); + + auto local_src1 = *src1; + local_src1.nb[2] = local_src1.nb[3] = 0; + + const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc src1_quantized(ctx.pool()); + if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) { + GGML_ASSERT(src1->ne[0] % QK8_1 == 0); + auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; + local_src1.data = src1_quantized.alloc(src_1_ddq_size); + // Note: no use is currently made of the quantization type passed into quantize_row_q8_1_cuda. + // If that were to change, we would need to adjust the code to handle src0_1->type != src0_2->type + quantize_row_q8_1_cuda((const float *)src1->data, (void *)src1_quantized.get(), src1->ne[0], 1, 1, src1_padded_col_size, + src0_1->type, stream); + CUDA_CHECK(cudaGetLastError()); + + local_src1.nb[1] = src_1_ddq_size; + } + + local_dst.data = dst_up_contiguous.get(); + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, + (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(), + 0, src0_1->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + local_dst.data = dst_gate_contiguous.get(); + ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst, + (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(), + 0, src0_2->ne[1], 1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && + ggml_backend_buffer_is_cuda(next->src[0]->buffer) && + !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && + ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id && + ggml_backend_buffer_is_cuda(next->buffer) && + ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { + + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids, + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); + + const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); + GGML_ASSERT(dst->ne[0] % QK8_1 == 0); + auto dst_row_size = dst_padded_col_size*sizeof(block_q8_1)/QK8_1; + auto dst_ddq_size = n_ids*dst_row_size; + ggml_cuda_pool_alloc dst_quantized(ctx.pool(), dst_ddq_size); + quantize_row_q8_1_cuda((const float *)dst_gate_contiguous.get(), (void *)dst_quantized.get(), dst->ne[0], n_ids, 1, + dst_padded_col_size, next->src[0]->type, stream); + CUDA_CHECK(cudaGetLastError()); + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + local_dst.ne[2] = 1; + + auto local_next = *next; + local_next.ne[2] = local_next.ne[1]; + local_next.ne[1] = local_next.ne[3] = 1; + local_next.nb[2] = local_next.nb[1]; + + local_src1 = *next->src[1]; + local_src1.ne[1] = local_src1.ne[2] = local_src1.ne[3] = 1; + local_src1.nb[1] = local_src1.nb[2] = local_src1.nb[3] = dst_row_size; + + auto local_src0 = *next->src[0]; + local_src0.ne[2] = local_src0.ne[3] = 1; + + ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next, + (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data, + 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + return true; + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + CUDA_CHECK(cudaGetLastError()); + return false; + } + } + } + + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers"); @@ -2299,49 +2547,47 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor if (fuse_down) { final_dst.src[1] = &dst_row; } - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + id*ids->nb[0]); - if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(i02 >= 0 && i02 < n_as); + if (i02 < 0 || i02 >= n_as) continue; + //GGML_ASSERT(i02 >= 0 && i02 < n_as); - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; + const int64_t i11 = id % ne11; + const int64_t i12 = 0; - const int64_t i1 = id; - const int64_t i2 = i12; + const int64_t i1 = id; + const int64_t i2 = i12; - src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - //dst_row.data = dst_original + i1*nb1 + i2*nb2; + src0_1_row.data = src0_1_original + i02*nb02; + src0_2_row.data = src0_2_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + //dst_row.data = dst_original + i1*nb1 + i2*nb2; - dst_row.data = dst_up_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + dst_row.data = dst_up_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_1_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + dst_row.data = dst_gate_contiguous.get(); + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + CUDA_CHECK(cudaGetLastError()); + + if (fuse_down) { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); CUDA_CHECK(cudaGetLastError()); - dst_row.data = dst_gate_contiguous.get(); - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row); + final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; + final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; + ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); CUDA_CHECK(cudaGetLastError()); - if (fuse_down) { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); - CUDA_CHECK(cudaGetLastError()); + } else { - final_src.data = (char *)next->src[0]->data + i02*next->src[0]->nb[2]; - final_dst.data = (char *)next->data + i1*next->nb[1] + i2*next->nb[2]; - ggml_cuda_mul_mat(ctx, &final_src, &dst_row, &final_dst); - CUDA_CHECK(cudaGetLastError()); + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); + CUDA_CHECK(cudaGetLastError()); - } else { - - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst_row.ne[0], - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)(dst_original + i1*nb1 + i2*nb2)); - CUDA_CHECK(cudaGetLastError()); - - } } } } else { diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 2eafe463..576c387d 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -96,11 +96,13 @@ template __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __global__ void iqk_mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size, - const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) { + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; - const char * cx = (const char *)vx + i2*nb02; + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) return; + const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; char * cdst = (char *)dst + i2*nb2; iqk_mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); @@ -108,9 +110,9 @@ __global__ void iqk_mul_mat_vec_q( template void iqk_mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); //GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -152,28 +154,28 @@ void iqk_mul_mat_vec_q_cuda( switch (ncols_y) { case 1: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 2: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 3: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 4: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 5: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 6: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 7: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; case 8: - iqk_mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2); + iqk_mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, row_size, nb02, nb12, nb2, ids_nb0); break; default: GGML_ASSERT(false); @@ -742,79 +744,79 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1( } // namespace void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq3_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq4_kss_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq6_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { - iqk_mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { + iqk_mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index a128bc53..15df20f5 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -1,51 +1,51 @@ #include "common.cuh" void mul_mat_vec_iq2_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq3_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq5_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq6_k_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq4_kss_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_ks_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq1_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); void mul_mat_vec_iq2_bn_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream); + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index e17e77a3..b9e9c216 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -134,11 +134,12 @@ template __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, - const uint64_t nb02, const uint64_t nb12, const uint64_t nb2) { + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; - const char * cx = (const char *)vx + i2*nb02; + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; char * cdst = (char *)dst + i2*nb2; mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); @@ -146,9 +147,9 @@ static __global__ void mul_mat_vec_q( template static void mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); @@ -188,28 +189,28 @@ static void mul_mat_vec_q_cuda( switch (ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; default: GGML_ABORT("fatal error"); @@ -218,169 +219,169 @@ static void mul_mat_vec_q_cuda( } static void mul_mat_vec_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q6_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq2_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq3_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq1_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq1_m_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq4_nl_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq4_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void mul_mat_vec_iq3_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, + const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, cudaStream_t stream) { + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); } static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, const int64_t ne00, const int64_t ne0, const int64_t ne2, - const int64_t nb02, const int64_t nb12, const int64_t nb2, - const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, + const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0, + const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream) { @@ -391,98 +392,97 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); switch (type) { case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q6_0: - mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ1_BN: - mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_BN: - mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_K: - mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ3_K: - mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq3_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_K: - mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_KS: - mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ4_KSS: - mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ2_KS: - mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ5_K: - mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq5_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ6_K: - mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, stream); + mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; default: GGML_ABORT("fatal error"); @@ -505,14 +505,12 @@ void ggml_cuda_op_mul_mat_vec_q_3D( const int64_t ne0 = dst->ne[0]; - int id = ggml_cuda_get_device(); - const int64_t src1_row_size = ggml_row_size(GGML_TYPE_Q8_1, src1_padded_row_size); ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, dst->ne[2], - src0->nb[2], src1_row_size, dst->nb[2], - src0_dd_i, src1_ddq_i, dst_dd_i, + src0->nb[2], src1_row_size, dst->nb[2], 0, + src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, row_low, row_high, src1_ncols, src1_padded_row_size, stream); @@ -531,11 +529,35 @@ void ggml_cuda_op_mul_mat_vec_q( const int64_t ne0 = dst->ne[0]; - int id = ggml_cuda_get_device(); - ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, - ne00, ne0, 1, 0, 0, 0, - src0_dd_i, src1_ddq_i, dst_dd_i, + ne00, ne0, 1, 0, 0, 0, 0, + src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, + row_low, row_high, src1_ncols, + src1_padded_row_size, stream); + + GGML_UNUSED(src1_ddf_i); +} + +void ggml_cuda_op_mul_mat_vec_q_id( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); + GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1); + GGML_ASSERT(ids->ne[0] == dst->ne[2]); + + const int64_t ne0 = dst->ne[0]; + + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne0, dst->ne[2], + src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], + src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data, row_low, row_high, src1_ncols, src1_padded_row_size, stream); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index f86aefe2..c0699b04 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -11,3 +11,9 @@ void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); + +void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/src/llama.cpp b/src/llama.cpp index 7d665072..bad8d33d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13734,6 +13734,9 @@ struct llm_build_context { } ggml_tensor * kq = ggml_mul_mat(ctx0, kv_cache, q); + if (kv_cache->ne[1] < 256) { + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } cb(kq, "kq", il); if (!pp_opt) { From a48e16324770bb829406d06e11be1df0c8a3b517 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 10 Mar 2025 16:19:09 +0200 Subject: [PATCH 036/132] DeepSeek imatrix stuff (#250) * This gives us ~20% TG speedup for DeepSeek on CUDA * Slightly better * Also do it for plain (not fused) mul_mat_id * Guard against numerical precision issues for MLA on CUDA * imatrix: wv_b <-> wkv_b --------- Co-authored-by: Iwan Kawrakow --- examples/imatrix/imatrix.cpp | 2 +- ggml/src/ggml-cuda/cpy.cu | 2 +- ggml/src/ggml.c | 2 +- src/llama.cpp | 18 ++++++++++++++++++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 8006988c..d8a43049 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -195,7 +195,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (m_params.verbosity > 1) { printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); } - for (int row = 0; row < (int)src1->ne[1]; ++row) { + for (int row = 0; row < (int)(src1->ne[1]*src1->ne[2]); ++row) { const float * x = data + row * src1->ne[0]; for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] += x[j]*x[j]; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 0b269a86..fabe8843 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -556,7 +556,7 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; + return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; } else { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4089e9b7..88820438 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10468,7 +10468,7 @@ static void ggml_compute_forward_dup_bytes( if (ggml_is_contiguous(dst)) { size_t id = 0; char * dst_ptr = (char *) dst->data; - const size_t rs = ne00 * type_size; + const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size; if (nb00 == type_size) { // src0 is contigous on first dimension, copy by rows diff --git a/src/llama.cpp b/src/llama.cpp index bad8d33d..ba5c5052 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13787,6 +13787,7 @@ struct llm_build_context { ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); cb(wv_b, "wv_b", il); + std::memcpy(wv_b->name, model.layers[il].wv_b->name, GGML_MAX_NAME); kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); @@ -17347,6 +17348,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const float * imatrix = nullptr; if (imatrix_data) { auto it = imatrix_data->find(tensor->name); + if (it == imatrix_data->end()) { + // MLA hack: most imatrix files floating around the Internet have been computed with standard attention. + // This means that the imatrix file does not contain data for the *.attn_k_b.weight and *.attn_v_b.weight + // required by MLA. But the *.attn_v_b.weight tensors "see" the exact same activations as the + // *.attn_kv_b.weight tensors used in standard attention. Hence, if we find imatrix data for + // *.attn_kv_b.weight we can use it for *.attn_v_b.weight and vice versa. + std::string name{tensor->name}; + static std::array alternatives{".attn_v_b.weight", ".attn_kv_b.weight"}; + for (int j = 0; j < int(alternatives.size()); ++j) { + if (auto pos = name.find(alternatives[j]); pos != std::string::npos) { + int j1 = (j + 1) % alternatives.size(); + auto alternative_name = name.substr(0, pos) + alternatives[j1]; + it = imatrix_data->find(alternative_name); + break; + } + } + } if (it == imatrix_data->end()) { LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { From 3f23ed68f17583a8ee63afd0c214f5b39226226c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 12 Mar 2025 07:21:46 +0200 Subject: [PATCH 037/132] MLA-2: Allow usage of q8_0 for KV cache on CUDA (#252) * FlashMLA(CUDA): WIP to allow q8_0 quantized cache * WIP * FlashMLA(CUDA) - allow q8_0 for KV cache This works, and PP is not bad, but TG is still quite a bit slower. * FlashMLA(CUDA) - allow q8_0 for KV cache This is better. ~9% slower than f16 cache for short contexts, nearly on par at 16k tokens. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 11 ++- ggml/src/ggml-cuda/cpy.cu | 137 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 141 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index f25dd725..1bb869c3 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2296,9 +2296,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - if (row_id_i != i02) { continue; } @@ -3458,6 +3455,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return true; } + if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) { + if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) { + return true; + } + } + if (ggml_are_same_shape(op->src[0], op->src[1]) && op->src[0]->type == GGML_TYPE_Q8_0 && op->src[1]->type == GGML_TYPE_Q8_0) { + return true; + } return false; } break; case GGML_OP_DUP: diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index fabe8843..44eba389 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,5 @@ #include "cpy.cuh" +#include "convert.cuh" typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -65,6 +66,66 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +//static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne, +// const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) { +// const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; +// +// if (i >= ne) { +// return; +// } +// +// const int64_t i03 = i/(ne00 * ne01 * ne02); +// const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01); +// const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; +// const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; +// +// const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); +// const int ib = i00/QK8_0; +// const int iq = i00%QK8_0; +// +// dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq]; +//} + +static __global__ void k_transpose_q8_0(const char * cx, char * cdst, + const int ne10, const int ne11, const int ne12, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + + //const int64_t ne00 = ne11; + //const int64_t ne01 = ne10; + //const int64_t ne02 = ne12; + const int64_t i03 = i13; + const int64_t i02 = i12; + const int64_t i01 = i10; //(i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; + const int64_t i00 = i11; //i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; + + const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); + const int ib0 = i00/QK8_0; + const int iq0 = i00%QK8_0; + + float xi = __half2float(q8[ib0].d)*q8[ib0].qs[iq0]; + float amax = fabsf(xi); + amax = warp_reduce_max(amax); + + //printf("%d, %d, %d: i = %ld, i11 = %ld i10 = %ld, xi = %g, amax = %g\n", blockDim.x, blockIdx.x, threadIdx.x, i, i11, i10, xi, amax); + + float d = amax/127; + int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + block_q8_0 * dst = (block_q8_0 *)(cdst + i11*nb11 + i12*nb12 + i13*nb13); + dst[i10 / QK8_0].qs[i10 % QK8_0] = q; + + if (threadIdx.x == 0) { + dst[i10 / QK8_0].d = __float2half(d); + } +} + static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; block_q8_0 * dsti = (block_q8_0 *) cdsti; @@ -464,6 +525,32 @@ static void ggml_cpy_f16_f16_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + auto stream = ctx.stream(); + auto num_blocks = ggml_nelements(dst)/QK8_0; + k_transpose_q8_0<<>>( + (const char *)src->data, (char *)dst->data, + dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3], + dst->nb[1], dst->nb[2], dst->nb[3]); + + //auto ne = ggml_nelements(dst); + //ggml_cuda_pool_alloc dst_f32(ctx.pool(), ne); + //const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + //auto aux_src = *dst; + //aux_src.nb[0] = sizeof(float); + //aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0]; + //aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1]; + //aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2]; + //cpy_q8_0_f32<<>> + // ((const char *)src->data, dst_f32.get(), ne, + // src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]); + //CUDA_CHECK(cudaGetLastError()); + //aux_src.type = GGML_TYPE_F32; + //ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2], + // aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3], + // dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream); +} + void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -522,9 +609,33 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) { + to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) { + to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) { + to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); + } + } + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + transpose_q8_0(ctx, src0, src1); } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); + fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__, + src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]); GGML_ABORT("fatal error"); } } @@ -559,9 +670,27 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; - } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) return (void*)to_fp16; + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) return (void*)to_fp32; + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) return (void*)to_bf16; + } } + else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + return (void *)transpose_q8_0; + } + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + fprintf(stderr, "%s: %ld x %ld x %ld; %zu x %zu %zu -> %ld x %ld x %ld; %zu x %zu x %zu\n", __func__, + src0->ne[0], src0->ne[1], src0->ne[2], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->nb[1], src1->nb[2], src1->nb[3]); + GGML_ABORT("fatal error"); } From 305fabfc3b694d603fdb05d671dd59e2d4c7d58e Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 13 Mar 2025 12:07:43 +0200 Subject: [PATCH 038/132] FlashMLA-2 (CPU): faster and smaller compute buffer size (#253) * FlashMLA-2: eliminate intermediate f32 tensors This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller. * FlashMLA-2: enable fast path only on the CPU for now I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only. * FlashMLA-2: slightly smaller computer buffer size --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-backend.c | 3 +- ggml/src/ggml.c | 123 +++++++++++++++++++++++++++++----- ggml/src/iqk/iqk_quantize.cpp | 28 ++++++++ ggml/src/iqk/iqk_quantize.h | 8 +++ src/llama.cpp | 111 +++++++++++++++++++++--------- 5 files changed, 225 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index 0458bd0c..fd538f50 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -843,7 +843,8 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const op->type != GGML_TYPE_IQ1_S && op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: - return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; + return true; + //return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; default: return true; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 88820438..a904464e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -12589,6 +12589,43 @@ static void ggml_compute_forward_repeat_f16( } } +static void ggml_compute_forward_repeat_any( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(ggml_can_repeat(src, dst)); + GGML_ASSERT(src->type == dst->type); + GGML_ASSERT(src->nb[0] == ggml_type_size(src->type)); + int64_t src_row_size = ggml_row_size(src->type, src->ne[0]); + GGML_ASSERT((int64_t )dst->nb[1] == src_row_size*dst->ne[0]/src->ne[0]); + + int ith = params->ith; + int nth = params->nth; + + int64_t nrows = ggml_nrows(dst); + int64_t nrows_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*nrows_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + nrows_per_thread, nrows); + + for (int64_t row = first_row; row < last_row; ++row) { + int64_t i3 = row/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]; + int64_t i03 = i3 % src->ne[3]; + int64_t i02 = i2 % src->ne[2]; + int64_t i01 = i1 % src->ne[1]; + const char * x = (const char *)src->data + i01*src->nb[1] + i02*src->nb[2] + i03*src->nb[3]; + for (int64_t ir = 0; ir < dst->ne[0]/src->ne[0]; ++ir) { + memcpy(y, x, src_row_size); + y += src_row_size; + } + } +} + static void ggml_compute_forward_repeat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -12609,7 +12646,8 @@ static void ggml_compute_forward_repeat( } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_repeat_any(params, dst); + //GGML_ABORT("fatal error"); } } } @@ -12762,6 +12800,44 @@ static void ggml_compute_forward_concat_f32( } } +static void ggml_compute_forward_concat_any( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == src1->type && src0->type == dst->type); + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + // Let's do it for dim = 0 only for now + GGML_ASSERT(dim == 0); + + int ith = params->ith; + int nth = params->nth; + + int64_t nrows = ggml_nrows(dst); + int64_t nrows_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*nrows_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + nrows_per_thread, nrows); + + int64_t src0_row_size = ggml_row_size(src0->type, src0->ne[0]); + int64_t src1_row_size = ggml_row_size(src1->type, src1->ne[0]); + + for (int64_t row = first_row; row < last_row; ++row) { + int64_t i3 = row/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (row - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = row - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + char * y = (char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]; + const char * x0 = (const char *)src0->data + i1*src0->nb[1] + i2*src0->nb[2] + i3*src0->nb[3]; + const char * x1 = (const char *)src1->data + i1*src1->nb[1] + i2*src1->nb[2] + i3*src1->nb[3]; + memcpy(y, x0, src0_row_size); + memcpy(y + src0_row_size, x1, src1_row_size); + } + +} + static void ggml_compute_forward_concat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -12776,7 +12852,8 @@ static void ggml_compute_forward_concat( } break; default: { - GGML_ABORT("fatal error"); + ggml_compute_forward_concat_any(params, dst); + //GGML_ABORT("fatal error"); } } } @@ -14302,7 +14379,17 @@ UseGgmlGemm1:; const size_t nbw3 = nbw2*ne12; assert(params->wsize >= ne13*nbw3); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + if (src1->type != GGML_TYPE_F32) { +#if GGML_USE_IQK_MULMAT + char * work_buffer = wdata + ne13*nbw3 + ith*ne10*sizeof(float); + GGML_ASSERT(params->wsize >= ne13*nbw3 + nth*ne10*sizeof(float)); + iqk_quantize_any(src1->type, vec_dot_type, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + src1->data, wdata, work_buffer, type_traits[src1->type].to_float, from_float, ith, nth); +#else + GGML_ABORT("fatal error"); +#endif + } + else { //#ifdef GGML_USE_IQK_MULMAT // int ts = type_traits[vec_dot_type].type_size; @@ -14348,6 +14435,7 @@ UseGgmlGemm1:; } } //#endif + } ggml_barrier(params->shared); @@ -16250,28 +16338,28 @@ static void ggml_compute_forward_soft_max_f32( } } -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif +//#ifndef NDEBUG +// for (int i = 0; i < nc; ++i) { +// //printf("p[%d] = %f\n", i, p[i]); +// assert(!isnan(wp[i])); +// } +//#endif float max = -INFINITY; ggml_vec_max_f32(nc, &max, wp); ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); + //assert(sum > 0.0); sum = 1.0/sum; ggml_vec_scale_f32(nc, dp, sum); -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif +//#ifndef NDEBUG +// for (int i = 0; i < nc; ++i) { +// assert(!isnan(dp[i])); +// assert(!isinf(dp[i])); +// } +//#endif } } @@ -21498,6 +21586,9 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa if (node->src[1]->type != vec_dot_type) { cur = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + if (node->src[1]->type != GGML_TYPE_F32) { + cur += n_tasks*node->src[1]->ne[0]*sizeof(float); // src1->type -> f32 -> vec_dot_type + } } } break; case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index b61ae2db..fb6a5db4 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -185,6 +185,34 @@ void IQ1BNQuantizer::quantize_one_row_2bn(const float * src, block_iq2_bn * y, i } +void iqk_quantize_any(int from_type, int to_type, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, + uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3, + const void * x, void * y, void * work_buffer, + to_float_t to_float, from_float_t from_float, int ith, int nth) { + auto type_x = ggml_type(from_type); + GGML_ASSERT(ggml_type_size(type_x) == nb0); + auto type_y = ggml_type(to_type); + auto row_size_y = ggml_row_size(type_y, ne0); + int64_t nrows = ne1*ne2*ne3; + int64_t nrows_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = nrows_per_thread*ith; + if (first_row >= nrows) return; + int64_t last_row = std::min(first_row + nrows_per_thread, nrows); + for (int64_t row = first_row; row < last_row; ++row) { + int64_t i3 = row/(ne1*ne2); + int64_t i2 = (row - i3*ne1*ne2)/ne1; + int64_t i1 = row - i3*ne1*ne2 - i2*ne1; + const char * cx = (const char *)x + i1*nb1 + i2*nb2 + i3*nb3; + // TODO: special case common types such as f16, q8_0 + // (although the performance gains may be too small to justify the added complexity) + to_float((const void *)cx, (float *)work_buffer, ne0); + auto cy = (char *)y + (i3*ne1*ne2 + i2*ne1 + i1)*row_size_y; + from_float((const float *)work_buffer, (void *)cy, ne0); + } +} + + size_t quantize_iq1_bn(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { IQ1BNQuantizer iq1bn; auto row_size = ggml_row_size(GGML_TYPE_IQ1_BN, n_per_row); diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 76fbac3b..d447705b 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -248,6 +248,14 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor); // So we can re-pack Microsoft's BitNet I2_S quants void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +typedef void (*to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +typedef void (*from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void iqk_quantize_any(int from_type, int to_type, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, + uint64_t nb0, uint64_t nb1, uint64_t nb2, uint64_t nb3, + const void * GGML_RESTRICT x, void * GGML_RESTRICT y, void * work_buffer, + to_float_t to_float, from_float_t from_float, int ith, int nth); + #ifdef __cplusplus } #endif diff --git a/src/llama.cpp b/src/llama.cpp index ba5c5052..cc15cf33 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13630,45 +13630,94 @@ struct llm_build_context { if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { - // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not - // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix - // multiplication, which *must* be f32. - auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); - auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32); - cb(kv_cache_view_f32, "kv_cache_view_f32", il); - // The no- and rotational position encoding portions of the KV cache - auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); - auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, - kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); + ggml_tensor * k; + ggml_tensor * v; - auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); - cb(kv_f32, "kv_f32", il); + // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend. + // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not + // be taken even if no layers were offloaded to the GPU. + if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) { - auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, - ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); - cb(k_nope_f32, "k_nope_f32", il); + auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); - ggml_tensor repeater; - repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; - auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); - cb(k_rope_f32, "k_rope_f32", il); + auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); + cb(kv_f32, "kv_f32", il); - auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); - cb(k_f32, "k_f32", il); + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + cb(v_f32, "v_f32", il); - auto k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type); - cb(k, "k", il); + v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); + cb(v, "v", il); - auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, - ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); - cb(v_f32, "v_f32", il); + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + cb(k_nope_f32, "k_nope_f32", il); - auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); - cb(v, "v", il); + auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type); + cb(k_nope, "k_nope", il); + + ggml_build_forward_expand(gf, k_nope); + ggml_build_forward_expand(gf, v); + + auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, + kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1; + auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater); + cb(k_rope, "k_rope", il); + + k = ggml_concat(ctx0, k_nope, k_rope, 0); + cb(k, "k", il); + + ggml_build_forward_expand(gf, k); + } + else { + // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not + // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix + // multiplication, which *must* be f32. + auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); + auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32); + cb(kv_cache_view_f32, "kv_cache_view_f32", il); + + // The no- and rotational position encoding portions of the KV cache + auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); + auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, + kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); + + auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); + cb(kv_f32, "kv_f32", il); + + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); + cb(k_nope_f32, "k_nope_f32", il); + + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; + auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); + cb(k_rope_f32, "k_rope_f32", il); + + auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); + cb(k_f32, "k_f32", il); + + k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type); + cb(k, "k", il); + + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, + ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); + cb(v_f32, "v_f32", il); + + v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); + cb(v, "v", il); + } auto q = ggml_concat(ctx0, q_nope, q_rope, 0); q = ggml_permute(ctx0, q, 0, 2, 1, 3); From f91b2e38d028c77cc5631295ba0937749e684749 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 17 Mar 2025 09:31:56 +0100 Subject: [PATCH 039/132] Prepare wk_b tensors of DeepSeek models on the fly (#259) * FlashMLA-2: eliminate intermediate f32 tensors This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller. * FlashMLA-2: enable fast path only on the CPU for now I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only. * FlashMLA-2: slightly smaller computer buffer size * Prepare wk_b when loading DeepSeek models (if wk_b is missing) * Add some comments * Fix case where wkv_b is quantized with k- or i-quants. * Fix CUDA There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is not contiguous. So, for now, we also create wv_b during model loading and use that instead of the 3D view of wkv_b. --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 173 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 152 insertions(+), 21 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index cc15cf33..34934a15 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2640,6 +2640,9 @@ struct llama_layer { struct ggml_tensor * ffn_gate_scale; struct ggml_tensor * ffn_up_scale; struct ggml_tensor * ffn_down_scale; + + std::unique_ptr computed_wk_b; + std::unique_ptr computed_wv_b; }; struct llama_kv_cell { @@ -3186,17 +3189,6 @@ static bool llama_kv_cache_init( ggml_tensor * k; ggml_tensor * v; if (cparams.mla_attn) { - if (!model.layers[i].wk_b || !model.layers[i].wv_b) { - if (warn) { - LLAMA_LOG_WARN("=======================================================================================\n"); - LLAMA_LOG_WARN("%s: missing MLA tensors => disabling MLA\n", __func__); - LLAMA_LOG_WARN("%s: you need to reconvert your model in order to use MLA\n", __func__); - LLAMA_LOG_WARN("=======================================================================================\n"); - warn = false; - } - } - } - if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) { // DeepSeek MLA const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; @@ -8130,6 +8122,139 @@ static bool llm_load_tensors( } } + if (model.arch == LLM_ARCH_DEEPSEEK2) { + int n_to_compute = 0; + for (auto& l : model.layers) { + if (!l.wk_b) ++n_to_compute; + } + if (n_to_compute > 0) { + // Prepare wk_b tensors to enable MLA usage also for model files that do not include + // the wk_b tensors (because, e.g., they were converted using mainline llama.cpp) + // We do it here because otherwise wkv_b may get run-time-repacked, which will make + // preparation of wk_b impossible. It also has the benefit that wk_b will get automatically + // run-time repacked if the rtr option is set. The downside is that we will prepare wk_b + // even if it is not needed (because MLA is not being used). If we wanted to avoid + // computing wk_b from wkv_b if not needed, we would need to propagate the context parameters + // to the model loading function. On the other hand, in some hypothetical bright future, + // where we are able to use the optimum settings for the computation, which for DeepSeekV3/R1/Lite + // is no MLA + FA for prompt processing, and MLA + FA for token generation, it would be useful + // to change the MLA setting on the fly, depending on context. In that case, having prepared + // the MLA tensors here is the right ting to do^TM. + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const int32_t n_embd_head_v = hparams.n_embd_head_v; + const int32_t n_head = hparams.n_head(0); + std::vector work_data; + LLAMA_LOG_INFO("============ %s: need to compute %d wk_b tensors\n", __func__, n_to_compute); + for (int il = 1; il < n_layer; ++il) { + // Somehow the number of heads is being defined as being per layer. Not sure why this is the + // case, but for now we do not support strange models that have different numbers of heads + // in different model layers. + if (hparams.n_head(il) != n_head) throw std::runtime_error("Unsupported configuration"); + } + auto total_size_wkb = 0; + size_t max_wkv_size = 0; + size_t max_wk_size = 0; + for (auto& l : model.layers) { + if (!l.wk_b) { + auto new_type = ggml_is_quantized(l.wkv_b->type) ? GGML_TYPE_Q8_0 : l.wkv_b->type; + auto size = ggml_row_size(new_type, n_embd_head_qk_nope)*kv_lora_rank*n_head; + max_wk_size = std::max(max_wk_size, size); + if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) { + max_wkv_size = std::max(max_wkv_size, ggml_nbytes(l.wkv_b)); + } + } + } + auto context_size = max_wk_size + 2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float); + context_size *= 2; // just in case; + std::vector wkv_buffer; + if (max_wkv_size > 0) wkv_buffer.resize(max_wkv_size); + // So, transposing tensors and then making them contiguous as needed for wk_b may or may not + // be supported on all backends. Hence, to be sure that the preparation of wk_b will + // work correctly, we do it on the CPU backend. We then copy the resulting tensor data to + // the bacikend where wkv_b is stored. + ggml_init_params params{context_size, nullptr, true}; + auto ctx = ggml_init(params); + auto graph = ggml_new_graph_custom(ctx, 8, false); + std::vector tensor_data(2*n_embd_head_qk_nope*kv_lora_rank*n_head*sizeof(float) + max_wk_size); + for (int il = 0; il < n_layer; ++il) { + auto& l = model.layers[il]; + if (l.wk_b) continue; + auto wkv_b = *l.wkv_b; + if (!ggml_backend_buffer_is_host(l.wkv_b->buffer)) { + ggml_backend_tensor_get(l.wkv_b, wkv_buffer.data(), 0, ggml_nbytes(l.wkv_b)); + wkv_b.data = wkv_buffer.data(); + } + auto wk_b_view = ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_qk_nope, n_head, + l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), 0); + auto wk_b_f32 = ggml_cast(ctx, wk_b_view, GGML_TYPE_F32); + wk_b_f32->data = tensor_data.data(); + auto wk_b_f32_tview = ggml_transpose(ctx, wk_b_f32); + auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview); + wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32); + + auto new_type = ggml_is_quantized(wkv_b.type) ? GGML_TYPE_Q8_0 : wkv_b.type; + auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type); + wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t); + + ggml_build_forward_expand(graph, wk_b); + + auto plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2); + if (plan.work_size > work_data.size()) work_data.resize(plan.work_size); + plan.work_data = work_data.data(); + + auto status = ggml_graph_compute(graph, &plan); + if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wk_b"); + + auto name = std::string{"blk."} + std::to_string(il) + ".attn_k_b.weight"; + + l.computed_wk_b = std::make_unique(*wk_b); + l.computed_wk_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wk_b)); + l.computed_wk_b->data = ggml_backend_buffer_get_base(l.computed_wk_b->buffer); + l.computed_wk_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents + // of wk_b, which no longer exist, and will therefore crash. + for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wk_b->src[j] = nullptr; + ggml_set_name(l.computed_wk_b.get(), name.c_str()); + ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b)); + + l.wk_b = l.computed_wk_b.get(); + + ggml_graph_clear(graph); + auto wv_b = ggml_cont(ctx, ggml_view_3d(ctx, &wkv_b, kv_lora_rank, n_embd_head_v, n_head, + l.wkv_b->nb[1], l.wkv_b->nb[1]*(n_embd_head_qk_nope + n_embd_head_v), l.wkv_b->nb[1]*n_embd_head_qk_nope)); + wv_b->data = tensor_data.data(); + ggml_build_forward_expand(graph, wv_b); + plan = ggml_graph_plan(graph, std::thread::hardware_concurrency()/2); + if (plan.work_size > work_data.size()) work_data.resize(plan.work_size); + plan.work_data = work_data.data(); + status = ggml_graph_compute(graph, &plan); + if (status != GGML_STATUS_SUCCESS) throw std::runtime_error("Failed to compute wv_b"); + + name = std::string{"blk."} + std::to_string(il) + ".attn_v_b.weight"; + + l.computed_wv_b = std::make_unique(*wv_b); + l.computed_wv_b->buffer = ggml_backend_buft_alloc_buffer(ggml_backend_buffer_get_type(l.wkv_b->buffer), ggml_nbytes(wv_b)); + l.computed_wv_b->data = ggml_backend_buffer_get_base(l.computed_wv_b->buffer); + l.computed_wv_b->op = GGML_OP_NONE; // we absolutely need to do this, else the backend will attempt to find the parents + // of wk_b, which no longer exist, and will therefore crash. + for (int j = 0; j < GGML_MAX_SRC; ++j) l.computed_wv_b->src[j] = nullptr; + ggml_set_name(l.computed_wv_b.get(), name.c_str()); + ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b)); + + l.wv_b = l.computed_wv_b.get(); + + printf("Computed %s as %ld x %ld x %ld and stored in buffer %s\n", name.c_str(), wk_b->ne[0], wk_b->ne[1], wk_b->ne[2], + ggml_backend_buffer_name(l.computed_wk_b->buffer)); + + ggml_graph_clear(graph); + } + ggml_free(ctx); + } + } + if (use_mmap_buffer) { for (auto & mapping : ml.mappings) { model.mappings.emplace_back(std::move(mapping)); @@ -13595,7 +13720,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(kv_compressed, "kv_compressed", il); - if (lctx.cparams.mla_attn && model.layers[il].wk_b && model.layers[il].wv_b) { + if (lctx.cparams.mla_attn) { ggml_tensor * kv_cache_trans; @@ -13738,10 +13863,9 @@ struct llm_build_context { ggml_tensor * kqv_compressed; - struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, - ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), - ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank)*n_embd_head_qk_nope, 0); - cb(wk_b, "wk_b", il); + auto wkv_b = model.layers[il].wkv_b; + auto wk_b = model.layers[il].wk_b->ne[1] == kv_lora_rank ? model.layers[il].wk_b + : ggml_reshape_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head); q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); cb(q_nope, "q_nope_perm", il); @@ -13832,11 +13956,18 @@ struct llm_build_context { } } - struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, - ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), - ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank)*n_embd_head_v, 0); - cb(wv_b, "wv_b", il); - std::memcpy(wv_b->name, model.layers[il].wv_b->name, GGML_MAX_NAME); + auto wv_b = model.layers[il].wv_b; + if (wv_b->ne[1] != n_embd_head_v) { + wv_b = ggml_reshape_3d(ctx0, wv_b, kv_lora_rank, n_embd_head_v, n_head); + cb(wv_b, "wv_b", il); + } + // There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is + // not contiguous. So, for now, we create wv_b during model loading and use that + // instead of the commented out 3D view below. + //auto wv_b = ggml_view_3d(ctx0, wkv_b, kv_lora_rank, n_embd_head_v, n_head, + // wkv_b->nb[1], wkv_b->nb[1]*(n_embd_head_v + n_embd_head_qk_nope), + // wkv_b->nb[1]*n_embd_head_qk_nope); + //cb(wv_b, "wv_b", il); kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed); cb(kqv, "kqv", il); From dcdfad29f7d2b831f1c84751f00bda14cc359a84 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 18 Mar 2025 07:36:42 +0100 Subject: [PATCH 040/132] FlashMLA-2: reduce compute buffer size (CUDA and CPU) (#260) * FlashMLA-2: eliminate intermediate f32 tensors This works on the CPU. PP performance is ~13% better for 16k tokens and compute buffer is quite a bit smaller. * FlashMLA-2: enable fast path only on the CPU for now I did implement the necessary ops on CUDA, but something is still wrong there, so for now we only use it when running CPU-only. * FlashMLA-2: slightly smaller computer buffer size * Prepare wk_b when loading DeepSeek models (if wk_b is missing) * Add some comments * Fix case where wkv_b is quantized with k- or i-quants. * Fix CUDA There is an issue with quantized GEMV on CUDA when the left operand (the matrix) is not contiguous. So, for now, we also create wv_b during model loading and use that instead of the 3D view of wkv_b. * FlashMLA-2: avoid conversions to f32 also on CUDA * Be able to compute for more than 65535 tokens On CUDA just a quick hack that allows us to cancatenate tensors with more than 65535 rows along zroth dimension as needed by FlashMLA-2. Also needed some care in the perplexity tool to avoid int overflows when evaluating the computed logits. * Reduce memory usage for FlashMLA-2 Oh, also fix int overflow in the CUDA concat implementation. It is funny how the llama.cpp 64-bit police has gone (almost) everywhere and replaced 32-bit ints with 64-bit ints, needed or not, but hasn't done it where it is actually needed. --------- Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 6 +- ggml/src/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/binbcast.cu | 38 +++++--- ggml/src/ggml-cuda/concat.cu | 135 +++++++++++++++++++++-------- src/llama.cpp | 128 +++++++++++---------------- 5 files changed, 184 insertions(+), 125 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 372684f0..95aedce6 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -166,7 +166,7 @@ static void process_logits( break; } lock.unlock(); - const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]); + const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]); const double v = -results.log_softmax; local_nll += v; local_nll2 += v*v; @@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits, break; } lock.unlock(); - const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]); + const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]); local_nll += v; local_nll2 += v*v; } @@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par if (num_batches > 1 && n_outputs > 0) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab); } } diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 1bb869c3..58a44cf7 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3354,7 +3354,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) { return false; } - if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { + if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) { return false; } if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) { diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 5abbd43c..a2508350 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -248,17 +248,35 @@ static void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) { - GGML_ASSERT(src1->type == GGML_TYPE_F32); + //GGML_ASSERT(src1->type == GGML_TYPE_F32); - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); - } else { - fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, - ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + if (src1->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } + } + else if (src1->type == GGML_TYPE_F16) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const float *)src0_dd, (const half *)src1_dd, (float *)dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (float *)dst_dd, stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } + } + else { GGML_ABORT("fatal error"); } } diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index 4bde6d69..b40617f6 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,7 +1,7 @@ #include "concat.cuh" // contiguous kernels -static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -27,7 +27,35 @@ static __global__ void concat_f32_dim0(const float * x, const float * y, float * } } -static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { +// contiguous kernels +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne00, + int64_t nb02, int64_t nb12, int64_t nb2) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * nb2; + + if (nidx < ne00) { // src0 + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * nb02; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + (nidx - ne00) + + blockIdx.y * (ne0 - ne00) + + blockIdx.z * nb12; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne01) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -53,7 +81,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float * } } -static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { +static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int64_t ne0, const int64_t ne02) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -81,9 +109,23 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float * static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; + if (dim == 0 && ne1 >= 65536) { + int64_t nstep = (ne1 + 32767)/32768; + for (int64_t istep = 0; istep < nstep; ++istep) { + int64_t i1 = 32768*istep; + int64_t n1 = i1 + 32768 <= ne1 ? 32768 : ne1 - i1; + dim3 gridDim(num_blocks, n1, ne2); + const float * xi = x + i1*ne00; + const float * yi = y + i1*(ne0 - ne00); + float * dst_i = dst + i1*ne0; + concat_f32_dim0<<>>(xi, yi, dst_i, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1); + } + return; + } dim3 gridDim(num_blocks, ne1, ne2); if (dim == 0) { concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + //concat_f32_dim0<<>>(x, y, dst, ne0, ne00, ne00*ne01, (ne0-ne00)*ne01, ne0*ne1); return; } if (dim == 1) { @@ -150,52 +192,77 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + GGML_ASSERT(src0->type == src1->type && src0->type == dst->type); + cudaStream_t stream = ctx.stream(); const int32_t dim = ((int32_t *) dst->op_params)[0]; + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && + (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1))) { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + CUDA_CHECK(cudaMemcpyAsync((char *)dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *)dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); + return; + } + + if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) && + src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) { + // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name); + // GGML_ABORT("fatal error"); + //} + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], + dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream); + } + } else { + dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); + concat_f32_non_cont<<>>( + (const char *)src0->data, + (const char *)src1->data, + ( char *)dst->data, + src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3], + sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3], + sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3], + dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3], + sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim); + } + return; + } + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) { + // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name); + // GGML_ABORT("fatal error"); + //} const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; float * dst_d = (float *)dst->data; - if (dim == 3 || (dim == 2 && dst->ne[3] == 1) || (dim == 1 && dst->ne[2]*dst->ne[3] == 1)) { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); - } else { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), - src0->ne[0], src0->ne[1], src0->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); - } + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } - - //if (dim != 3) { - // for (int i3 = 0; i3 < dst->ne[3]; i3++) { - // concat_f32_cuda( - // src0_d + i3 * (src0->nb[3] / 4), - // src1_d + i3 * (src1->nb[3] / 4), - // dst_d + i3 * ( dst->nb[3] / 4), - // src0->ne[0], src0->ne[1], src0->ne[2], - // dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); - // } - //} else { - // const size_t size0 = ggml_nbytes(src0); - // const size_t size1 = ggml_nbytes(src1); - - // CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - // CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); - //} } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); concat_f32_non_cont<<>>( diff --git a/src/llama.cpp b/src/llama.cpp index 34934a15..605e5d36 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13755,31 +13755,52 @@ struct llm_build_context { if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { + auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); - ggml_tensor * k; - ggml_tensor * v; + auto kv_f32_size = model.layers[il].wkv_b->ne[1] * kv_cache_nope->ne[1] * sizeof(float) / (1024*1024); + int n_max_head = n_head; + if (cparams.attn_max_batch > 0 && kv_f32_size > cparams.attn_max_batch) { + while (n_max_head%2 == 0 && kv_f32_size > cparams.attn_max_batch) { + n_max_head /= 2; kv_f32_size /= 2; + } + } + GGML_ASSERT(n_head % n_max_head == 0); - // For now this only works in the CPU implementation, so we only use it if there is just the CPU backend. - // If the code was compiled with CUDA (and/or Metal, Vulkan, whatever) support, this branch will not - // be taken even if no layers were offloaded to the GPU. - if (lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu) { + auto n_per_head = model.layers[il].wkv_b->ne[1] / n_head; - auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); + auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, + kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); - auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); + ggml_tensor repeater; + repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1; + auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater); + cb(k_rope, "k_rope", il); + + auto q = ggml_concat(ctx0, q_nope, q_rope, 0); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + cb(q, "q_concat", il); + + ggml_build_forward_expand(gf, q); + + for (int iter = 0; iter < n_head/n_max_head; ++iter) { + + auto wkv_b = ggml_view_2d(ctx0, model.layers[il].wkv_b, model.layers[il].wkv_b->ne[0], n_per_head*n_max_head, + model.layers[il].wkv_b->nb[1], model.layers[il].wkv_b->nb[1]*n_per_head*n_max_head*iter); + + auto kv_f32 = ggml_mul_mat(ctx0, wkv_b, kv_cache_nope); cb(kv_f32, "kv_f32", il); - auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, - ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_max_head, + ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); cb(v_f32, "v_f32", il); - v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); + auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); cb(v, "v", il); - auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, - ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head, + ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); cb(k_nope_f32, "k_nope_f32", il); @@ -13789,74 +13810,27 @@ struct llm_build_context { ggml_build_forward_expand(gf, k_nope); ggml_build_forward_expand(gf, v); - auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, - kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); - - ggml_tensor repeater; - repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_head; repeater.ne[3] = 1; - auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater); - cb(k_rope, "k_rope", il); - - k = ggml_concat(ctx0, k_nope, k_rope, 0); + auto k = ggml_concat(ctx0, k_nope, k_rope, 0); cb(k, "k", il); ggml_build_forward_expand(gf, k); + + auto q_iter = ggml_view_3d(ctx0, q, q->ne[0], q->ne[1], n_max_head, + q->nb[1], q->nb[2], q->nb[2]*n_max_head*iter); + + kqv = ggml_flash_attn_ext(ctx0, q_iter, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); + if (q->ne[1] <= 8) { + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + } + cb(kqv, "kqv", il); + + if (iter == 0) { + cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens); + } else { + cur = ggml_concat(ctx0, cur, ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_max_head, n_tokens), 0); + } + } - else { - // Hahaha, we need to convert the KV cache for this layer to f32 because the general purpose ML library ggml does not - // provide ops on (almost) anything other than f32. In this case, the cache will be the second operand to a matrix - // multiplication, which *must* be f32. - auto kv_cache_view = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_self.kv_l[il]->ne[0], n_kv, kv_self.kv_l[il]->nb[1], 0); - auto kv_cache_view_f32 = ggml_cast(ctx0, kv_cache_view, GGML_TYPE_F32); - cb(kv_cache_view_f32, "kv_cache_view_f32", il); - - // The no- and rotational position encoding portions of the KV cache - auto kv_cache_nope = ggml_view_2d(ctx0, kv_cache_view_f32, kv_lora_rank, n_kv, kv_cache_view_f32->nb[1], 0); - auto kv_cache_rope = ggml_view_3d(ctx0, kv_cache_view_f32, n_embd_head_qk_rope, 1, n_kv, - kv_cache_view_f32->nb[1], kv_cache_view_f32->nb[1], ggml_row_size(kv_cache_view_f32->type, kv_lora_rank)); - - auto kv_f32 = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache_nope); - cb(kv_f32, "kv_f32", il); - - auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_head, - ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); - cb(k_nope_f32, "k_nope_f32", il); - - ggml_tensor repeater; - repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_head; repeater.ne[2] = n_kv; repeater.ne[3] = 1; - auto k_rope_f32 = ggml_permute(ctx0, ggml_repeat(ctx0, kv_cache_rope, &repeater), 0, 2, 1, 3); - cb(k_rope_f32, "k_rope_f32", il); - - auto k_f32 = ggml_concat(ctx0, k_nope_f32, k_rope_f32, 0); - cb(k_f32, "k_f32", il); - - k = ggml_cast(ctx0, k_f32, kv_self.kv_l[il]->type); - cb(k, "k", il); - - auto v_f32 = ggml_view_3d(ctx0, kv_f32, hparams.n_embd_head_v, n_kv, n_head, - ggml_row_size(kv_f32->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); - cb(v_f32, "v_f32", il); - - v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); - cb(v, "v", il); - } - - auto q = ggml_concat(ctx0, q_nope, q_rope, 0); - q = ggml_permute(ctx0, q, 0, 2, 1, 3); - cb(q, "q_concat", il); - - ggml_build_forward_expand(gf, q); - - kqv = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, kq_scale, hparams.f_max_alibi_bias, 0.f); - if (q->ne[1] <= 8) { - ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); - } - cb(kqv, "kqv", il); - - cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens); } else { From bdcae905c4cb0de1025a45a2bd6c2e646cc22be7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 18 Mar 2025 07:37:10 +0100 Subject: [PATCH 041/132] Compile time option to use bf16 for qunts without MMQ kernels (#261) Co-authored-by: Iwan Kawrakow --- ggml/CMakeLists.txt | 1 + ggml/src/CMakeLists.txt | 4 + ggml/src/ggml-cuda.cu | 40 +++++++++ ggml/src/ggml-cuda/convert.cu | 152 ++++++++++++++++++++++++++-------- 4 files changed, 162 insertions(+), 35 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6775fdcb..70e3bbf3 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -118,6 +118,7 @@ option(GGML_MUSA "ggml: use MUSA" option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) +option(GGML_CUDA_IQK_FORCE_BF16 "ggml: use bf16 cuBLAS when no MMQ kernel is available" OFF) set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c1ebd870..7013526c 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -362,6 +362,10 @@ if (GGML_CUDA) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() + if (GGML_CUDA_IQK_FORCE_BF16) + add_compile_definitions(GGML_CUDA_IQK_FORCE_BF16) + endif() + if (GGML_CUDA_FORCE_CUBLAS) add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) endif() diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 58a44cf7..a8f7383a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1264,6 +1264,46 @@ static void ggml_cuda_op_mul_mat_cublas( return; } +#ifdef GGML_CUDA_IQK_FORCE_BF16 + if (ggml_is_quantized(src0->type) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + size_t ne = row_diff*ne00; + ggml_cuda_pool_alloc src0_as_bf16(ctx.pool(id), ne); + to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream); + + ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); + if (src1->type != GGML_TYPE_BF16) { + size_t ne = src1_ncols*ne10; + src1_as_bf16.alloc(ne); + to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream); + } + const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); + const nv_bfloat16 * src0_ptr = src0_as_bf16.get(); + + ggml_cuda_pool_alloc dst_bf16(ctx.pool(id), row_diff*src1_ncols); + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, + src1_ptr, CUDA_R_16BF, ne10, + &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream); + return; + } +#endif + if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b9baee1b..8383f2d3 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -579,9 +579,16 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const float d = scale * ((x[i].scales[ib] & 254) - 127); const int8_t * values = iq4k_values + ((x[i].scales[ib] & 1) << 4); - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d * values[q4[j] & 0xf]; - y[j+16] = d * values[q4[j] >> 4]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d * values[q4[j] & 0xf]); + y[j+16] = __float2bfloat16(d * values[q4[j] >> 4]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[q4[j] & 0xf]; + y[j+16] = d * values[q4[j] >> 4]; + } } } @@ -610,9 +617,16 @@ static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, ds aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f); aux32[0] &= 0x0f0f0f0f; const uint8_t * aux8 = (const uint8_t *)aux32; - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d * values[aux8[j+0]]; - y[j+16] = d * values[aux8[j+4]]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d * values[aux8[j+0]]); + y[j+16] = __float2bfloat16(d * values[aux8[j+4]]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * values[aux8[j+0]]; + y[j+16] = d * values[aux8[j+4]]; + } } } @@ -632,9 +646,16 @@ static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_ const float d2 = d * (((x[i].scales_l[ib] >> 4) | ((sh << 2) & 0x30)) - 32); const int8_t * values1 = iq4k_values + 16*((x[i].extra >> (2*ib+0)) & 1); const int8_t * values2 = iq4k_values + 16*((x[i].extra >> (2*ib+1)) & 1); - for (int j = 0; j < 4; ++j) { - y[j+ 0] = d1 * values1[q4[j] & 0xf]; - y[j+16] = d2 * values2[q4[j] >> 4]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = __float2bfloat16(d1 * values1[q4[j] & 0xf]); + y[j+16] = __float2bfloat16(d2 * values2[q4[j] >> 4]); + } + } else { + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d1 * values1[q4[j] & 0xf]; + y[j+16] = d2 * values2[q4[j] >> 4]; + } } } @@ -656,12 +677,22 @@ static __global__ void dequantize_block_iq5_k(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 32*ib64 + 2*il; const uint8_t * qh = x[i].qh + 2*il; const uint8_t extra = x[i].extra >> 4*(ib64%4); - for (int j = 0; j < 2; ++j) { - const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); - y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; - y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; - y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; - y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = __float2bfloat16(dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]); + y[j+16] = __float2bfloat16(dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]); + y[j+32] = __float2bfloat16(dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]); + y[j+48] = __float2bfloat16(dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h1 = qh[j] >> 2*(ib64%4), h2 = qh[j+16] >> 2*(ib64%4); + y[j+ 0] = dl1 * iq5nl_values[(qs[j+ 0] & 0xf) | ((h1 & 1) << 4) | ((extra << 5) & 0x20)]; + y[j+16] = dl2 * iq5nl_values[(qs[j+16] & 0xf) | ((h2 & 1) << 4) | ((extra << 4) & 0x20)]; + y[j+32] = dl3 * iq5nl_values[(qs[j+ 0] >> 4) | ((h1 & 2) << 3) | ((extra << 3) & 0x20)]; + y[j+48] = dl4 * iq5nl_values[(qs[j+16] >> 4) | ((h2 & 2) << 3) | ((extra << 2) & 0x20)]; + } } } @@ -689,10 +720,17 @@ static __global__ void dequantize_block_iq6_k(const void * __restrict__ vx, dst_ uint8_t q2 = (qs[j+16] & 0xf) | ((h2 & 0x03) << 4); uint8_t q3 = (qs[j+ 0] >> 4) | ((h1 & 0x0c) << 2); uint8_t q4 = (qs[j+16] >> 4) | ((h2 & 0x0c) << 2); - y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); - y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); - y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); - y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); + if constexpr (std::is_same_v) { + y[j+ 0] = __float2bfloat16(dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0))); + y[j+16] = __float2bfloat16(dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0))); + y[j+32] = __float2bfloat16(dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0))); + y[j+48] = __float2bfloat16(dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0))); + } else { + y[j+ 0] = dl1 * (iq6nl_values[q1] + (extra & 1 ? 1 : 0)); + y[j+16] = dl2 * (iq6nl_values[q2] + (extra & 2 ? 1 : 0)); + y[j+32] = dl3 * (iq6nl_values[q3] + (extra & 4 ? 1 : 0)); + y[j+48] = dl4 * (iq6nl_values[q4] + (extra & 8 ? 1 : 0)); + } } } @@ -713,11 +751,20 @@ static __global__ void dequantize_block_iq2_k(const void * __restrict__ vx, dst_ const float dl4 = d * (((x[i].scales[4*ib128+3] >> 4*(il/8)) & 0xf) - 8); const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); - for (int j = 0; j < 2; ++j) { - y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; - y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; - y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; - y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]); + y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]); + y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]); + y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 0) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 2) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 4) & 4)]; + } } } @@ -741,11 +788,20 @@ static __global__ void dequantize_block_iq2_ks(const void * __restrict__ vx, dst const float dl3 = d * (((x[i].scales[2*ib128+1] & 0xf) | ((extra >> 6) & 0x10)) - 16); const float dl4 = d * (((x[i].scales[2*ib128+1] >> 4) | ((extra >> 7) & 0x10)) - 16); const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; - for (int j = 0; j < 2; ++j) { - y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; - y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]; - y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]; - y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = __float2bfloat16(dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]); + y[j+32] = __float2bfloat16(dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]); + y[j+64] = __float2bfloat16(dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]); + y[j+96] = __float2bfloat16(dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]); + } + } else { + for (int j = 0; j < 2; ++j) { + y[j+ 0] = dl1 * iq2nl_values[((qs[j] >> 0) & 0x03) + ((extra << 2) & 4)]; + y[j+32] = dl2 * iq2nl_values[((qs[j] >> 2) & 0x03) + ((extra << 1) & 4)]; + y[j+64] = dl3 * iq2nl_values[((qs[j] >> 4) & 0x03) + ((extra >> 0) & 4)]; + y[j+96] = dl4 * iq2nl_values[((qs[j] >> 6) & 0x03) + ((extra >> 1) & 4)]; + } } } @@ -768,12 +824,22 @@ static __global__ void dequantize_block_iq3_k(const void * __restrict__ vx, dst_ const uint8_t * qs = x[i].qs + 32*ib128 + 2*il; const uint8_t * qh = x[i].qh + 2*il; const int16_t extra = x[i].extra >> (8*ib128 + (il/8)); - for (int j = 0; j < 2; ++j) { - const uint8_t h = qh[j] >> (4*(ib128%2)); - y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; - y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; - y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; - y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + if constexpr (std::is_same_v) { + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = __float2bfloat16(dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]); + y[j+32] = __float2bfloat16(dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]); + y[j+64] = __float2bfloat16(dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]); + y[j+96] = __float2bfloat16(dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]); + } + } else { + for (int j = 0; j < 2; ++j) { + const uint8_t h = qh[j] >> (4*(ib128%2)); + y[j+ 0] = dl1 * iq3nl_values[(((qs[j] >> 0) & 0x03) | ((h & 0x01) << 2)) + ((extra << 3) & 8)]; + y[j+32] = dl2 * iq3nl_values[(((qs[j] >> 2) & 0x03) | ((h & 0x02) << 1)) + ((extra << 1) & 8)]; + y[j+64] = dl3 * iq3nl_values[(((qs[j] >> 4) & 0x03) | ((h & 0x04) >> 0)) + ((extra >> 1) & 8)]; + y[j+96] = dl4 * iq3nl_values[(((qs[j] >> 6) & 0x03) | ((h & 0x08) >> 1)) + ((extra >> 3) & 8)]; + } } } @@ -1064,6 +1130,22 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { return convert_to_bf16_cuda; case GGML_TYPE_F16: return convert_to_bf16_cuda; + case GGML_TYPE_IQ2_KS: + return dequantize_row_iq2_ks_cuda; + case GGML_TYPE_IQ2_K: + return dequantize_row_iq2_k_cuda; + case GGML_TYPE_IQ3_K: + return dequantize_row_iq3_k_cuda; + case GGML_TYPE_IQ4_KSS: + return dequantize_row_iq4_kss_cuda; + case GGML_TYPE_IQ4_KS: + return dequantize_row_iq4_ks_cuda; + case GGML_TYPE_IQ4_K: + return dequantize_row_iq4_k_cuda; + case GGML_TYPE_IQ5_K: + return dequantize_row_iq5_k_cuda; + case GGML_TYPE_IQ6_K: + return dequantize_row_iq6_k_cuda; default: return nullptr; } From f4ebf13b6a63ac1367bc392e24566d71c0b4c5b9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 18 Mar 2025 07:44:43 +0100 Subject: [PATCH 042/132] Fix #261 (#262) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 65 ++++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index a8f7383a..01f98594 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1267,40 +1267,41 @@ static void ggml_cuda_op_mul_mat_cublas( #ifdef GGML_CUDA_IQK_FORCE_BF16 if (ggml_is_quantized(src0->type) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src0->type); - GGML_ASSERT(to_bf16_cuda != nullptr); - size_t ne = row_diff*ne00; - ggml_cuda_pool_alloc src0_as_bf16(ctx.pool(id), ne); - to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream); + if (to_bf16_cuda) { + size_t ne = row_diff*ne00; + ggml_cuda_pool_alloc src0_as_bf16(ctx.pool(id), ne); + to_bf16_cuda(src0_dd_i, src0_as_bf16.get(), row_diff, ne00, stream); - ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); - if (src1->type != GGML_TYPE_BF16) { - size_t ne = src1_ncols*ne10; - src1_as_bf16.alloc(ne); - to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); - GGML_ASSERT(to_bf16_cuda != nullptr); - to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream); + ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); + if (src1->type != GGML_TYPE_BF16) { + size_t ne = src1_ncols*ne10; + src1_as_bf16.alloc(ne); + to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), src1_ncols, ne10, stream); + } + const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); + const nv_bfloat16 * src0_ptr = src0_as_bf16.get(); + + ggml_cuda_pool_alloc dst_bf16(ctx.pool(id), row_diff*src1_ncols); + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, + src1_ptr, CUDA_R_16BF, ne10, + &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream); + return; } - const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); - const nv_bfloat16 * src0_ptr = src0_as_bf16.get(); - - ggml_cuda_pool_alloc dst_bf16(ctx.pool(id), row_diff*src1_ncols); - - const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; - - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, - src1_ptr, CUDA_R_16BF, ne10, - &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, - CUBLAS_COMPUTE_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); - to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff, src1_ncols, stream); - return; } #endif From 68a5b60408b1085d2b2ed5de75e004ee23f8ddb9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 18 Mar 2025 15:40:47 +0100 Subject: [PATCH 043/132] Make Q8_0 KV cache work with mla=2,fa on CUDA (#264) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 8 ++++ ggml/src/ggml-cuda/binbcast.cu | 23 ++++++++- ggml/src/ggml-cuda/concat.cu | 24 ++++++++-- ggml/src/ggml-cuda/cpy.cu | 86 ++++++++++++++++++++-------------- src/llama.cpp | 22 +++++++-- 5 files changed, 117 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 01f98594..453bde0c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3395,6 +3395,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) { return false; } + //================================================================== + //if (ggml_is_quantized(a->type) && ggml_is_quantized(b->type)) { + // return false; + //} + //================================================================== if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) { return false; } @@ -3496,6 +3501,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return true; } + if (ggml_is_quantized(src0_type) && (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_F32)) { + return true; + } if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) { if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) { return true; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index a2508350..f217488e 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -282,7 +282,28 @@ static void ggml_cuda_op_bin_bcast( } void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + GGML_ASSERT(dst->type == dst->src[0]->type); + if (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16) { + ggml_cuda_op_bin_bcast>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); + return; + } + auto src = dst->src[0]; + auto bs = ggml_blck_size(src->type); + auto ts = ggml_type_size(src->type); + if (src->nb[0] != ts || ts*(src->ne[0]/bs) % 2 != 0) { + fprintf(stderr, "%s: unsupported case type = %s, nb[0] = %zu, type_size = %zu\n", __func__, ggml_type_name(src->type), src->nb[0], ts); + GGML_ABORT("fatal error"); + } + auto aux_src = *src; + aux_src.type = GGML_TYPE_F16; + aux_src.ne[0] = ts*(src->ne[0]/bs)/2; + aux_src.nb[0] = 2; + auto aux_dst = *dst; + aux_dst.type = GGML_TYPE_F16; + aux_dst.ne[0] = ts*(dst->ne[0]/bs)/2; + aux_dst.nb[0] = 2; + aux_dst.src[0] = &aux_src; + ggml_cuda_op_bin_bcast>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); } void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index b40617f6..f89ce974 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -209,6 +209,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) && src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) { + auto bs = ggml_blck_size(dst->type); + auto ts = ggml_type_size(dst->type); + auto ne00_eff = (src0->ne[0]/bs)*ts/sizeof(float); + auto ne0_eff = (dst->ne[0]/bs)*ts/sizeof(float); if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) { // fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name); @@ -217,25 +221,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; float * dst_d = (float *)dst->data; + //printf("%s(%s, %s): %ld %zu %zu %ld %zu %zu\n", __func__, src0->name, src1->name, src0->ne[0], src0->nb[0], src0->nb[1], dst->ne[0], dst->nb[0], dst->nb[1]); for (int i3 = 0; i3 < dst->ne[3]; i3++) { concat_f32_cuda( src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * ( dst->nb[3] / 4), - src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], - dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream); + ne00_eff, src0->ne[1], src0->ne[2], + ne0_eff, dst->ne[1], dst->ne[2], dim, stream); + //src0->nb[1]/sizeof(float), src0->ne[1], src0->ne[2], + //dst->nb[1]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream); + //src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], + //dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream); } } else { + //printf("%s(not contiguous): %s(%s) and %s(%s)\n", __func__, src0->name, ggml_type_name(src0->type), src1->name, ggml_type_name(src1->type)); + auto ne10_eff = (src1->ne[0]/bs)*ts/sizeof(float); dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); concat_f32_non_cont<<>>( (const char *)src0->data, (const char *)src1->data, ( char *)dst->data, - src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3], + ne00_eff, src0->ne[1], src0->ne[2], src0->ne[3], + //src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3], sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3], - src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3], + ne10_eff, src1->ne[1], src1->ne[2], src1->ne[3], + //src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3], sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3], - dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3], + ne0_eff, dst->ne[1], dst->ne[2], dst->ne[3], + //dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3], sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim); } return; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 44eba389..3f2d0089 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -66,25 +66,30 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } -//static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne, -// const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) { -// const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; -// -// if (i >= ne) { -// return; -// } -// -// const int64_t i03 = i/(ne00 * ne01 * ne02); -// const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01); -// const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; -// const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; -// -// const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); -// const int ib = i00/QK8_0; -// const int iq = i00%QK8_0; -// -// dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq]; -//} +template +static __global__ void k_cpy_q8_0_to_float(const char * cx, dst_t * dst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00; + + const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03); + const int ib = i00/QK8_0; + const int iq = i00%QK8_0; + + if constexpr (std::is_same_v) { + dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __float2bfloat16(__half2float(q8[ib].d)*q8[ib].qs[iq]); + } else { + dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq]; + } +} static __global__ void k_transpose_q8_0(const char * cx, char * cdst, const int ne10, const int ne11, const int ne12, @@ -532,23 +537,26 @@ static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor * (const char *)src->data, (char *)dst->data, dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3], dst->nb[1], dst->nb[2], dst->nb[3]); +} - //auto ne = ggml_nelements(dst); - //ggml_cuda_pool_alloc dst_f32(ctx.pool(), ne); - //const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - //auto aux_src = *dst; - //aux_src.nb[0] = sizeof(float); - //aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0]; - //aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1]; - //aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2]; - //cpy_q8_0_f32<<>> - // ((const char *)src->data, dst_f32.get(), ne, - // src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]); - //CUDA_CHECK(cudaGetLastError()); - //aux_src.type = GGML_TYPE_F32; - //ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2], - // aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3], - // dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream); +static void copy_q8_0_to_float(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + auto stream = ctx.stream(); + auto num_blocks = ggml_nelements(dst)/QK8_0; + if (dst->type == GGML_TYPE_F16) { + k_cpy_q8_0_to_float<<>>((const char *)src->data, (half *)dst->data, ggml_nelements(dst), + src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); + } + else if (dst->type == GGML_TYPE_F32) { + k_cpy_q8_0_to_float<<>>((const char *)src->data, (float *)dst->data, ggml_nelements(dst), + src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); + } + else if (dst->type == GGML_TYPE_BF16) { + k_cpy_q8_0_to_float<<>>((const char *)src->data, (nv_bfloat16 *)dst->data, ggml_nelements(dst), + src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]); + } + else { + GGML_ABORT("fatal error"); + } } void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { @@ -607,8 +615,13 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && + (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) { + copy_q8_0_to_float(ctx, src0, src1); } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { if (src1->type == GGML_TYPE_F16) { auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); @@ -670,6 +683,9 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; + } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && + (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) { + return (void*)copy_q8_0_to_float; } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { if (src1->type == GGML_TYPE_F16) { auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); diff --git a/src/llama.cpp b/src/llama.cpp index 605e5d36..76039f8e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13771,9 +13771,21 @@ struct llm_build_context { auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1, kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)); + // There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when + // the KV cache is quantized. Hence, in that case we will simply use fp16 for now. + // The downside of the following line is that fp16 will be used even if attention is computed on the CPU + // if the build is with CUDA enabled. + auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.kv_l[il]->type : GGML_TYPE_F16; + ggml_tensor repeater; repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1; - auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater); + ggml_tensor * k_rope; + if (kv_cache_rope->type == kv_type) { + k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater); + } else { + auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16); + k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater); + } cb(k_rope, "k_rope", il); auto q = ggml_concat(ctx0, q_nope, q_rope, 0); @@ -13796,15 +13808,15 @@ struct llm_build_context { ggml_row_size(kv_f32->type, n_embd_head_qk_nope)); cb(v_f32, "v_f32", il); - auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type); - cb(v, "v", il); - auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head, ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0); cb(k_nope_f32, "k_nope_f32", il); - auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type); + auto v = ggml_cast(ctx0, v_f32, kv_type); + cb(v, "v", il); + + auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type); cb(k_nope, "k_nope", il); ggml_build_forward_expand(gf, k_nope); From 8e549b42346f039be01a06495c6469d66d1c8926 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 18 Mar 2025 15:41:05 +0100 Subject: [PATCH 044/132] Allow q8_0 cache on the CPU for FlashMLA-2 (#265) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 64 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a904464e..1552d91b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10569,6 +10569,58 @@ static void ggml_compute_forward_dup_q( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + int64_t nrows = ggml_nrows(dst); + int ith = params->ith; + int nth = params->nth; + + if (dst->type == GGML_TYPE_Q8_0 && dst->src[0]->type == GGML_TYPE_Q8_0 && + ggml_are_same_shape(dst, dst->src[0])) { + + // we assume src is transposed and that's why we are here + + GGML_ASSERT(dst->ne[0] % QK8_0 == 0); + + struct ggml_tensor * const src = dst->src[0]; + GGML_ASSERT(src->nb[1] == sizeof(block_q8_0)); + + float aux[QK8_0]; + + int64_t n_per_thread = (nrows + nth - 1)/nth; + int64_t first_row = ith*n_per_thread; + if (first_row >= nrows) return; + int64_t last_row = MIN(first_row + n_per_thread, nrows); + + int64_t nblock = dst->ne[0] / QK8_0; + for (int64_t ir = first_row; ir < last_row; ++ir) { + int64_t i3 = ir/(dst->ne[1]*dst->ne[2]); + int64_t i2 = (ir - i3*dst->ne[1]*dst->ne[2])/dst->ne[1]; + int64_t i1 = ir - i3*dst->ne[1]*dst->ne[2] - i2*dst->ne[1]; + int ib0 = i1/QK8_0; + int iq0 = i1%QK8_0; + for (int ib = 0; ib < nblock; ++ib) { + block_q8_0 * dst_q8 = (block_q8_0 *)((char *)dst->data + i1*dst->nb[1] + i2*dst->nb[2] + i3*dst->nb[3]); + float amax = 0; + for (int j = 0; j < QK8_0; ++j) { + int64_t i0 = ib*QK8_0 + j; + const block_q8_0 * src_q8 = (const block_q8_0 *)((const char *)src->data + i0*src->nb[0] + i2*src->nb[2] + i3*src->nb[3]); + float xi = GGML_FP16_TO_FP32(src_q8[ib0].d) * src_q8[ib0].qs[iq0]; + aux[j] = xi; + xi = fabsf(xi); + amax = MAX(amax, xi); + } + float d = amax/127; + dst_q8[ib].d = GGML_FP32_TO_FP16(d); + if (d > 0) { + float id = 1/d; + for (int j = 0; j < QK8_0; ++j) dst_q8[ib].qs[j] = roundf(id*aux[j]); + } else { + memset(dst_q8[ib].qs, 0, QK8_0); + } + } + } + return; + } + GGML_ASSERT(dst->type == GGML_TYPE_F32); struct ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->ne[0] == dst->ne[0] && src0->nb[0] == ggml_type_size(src0->type)); @@ -10576,10 +10628,6 @@ static void ggml_compute_forward_dup_q( ggml_to_float_t to_float = type_traits[src0->type].to_float; GGML_ASSERT(to_float != NULL); - int64_t nrows = ggml_nrows(dst); - int ith = params->ith; - int nth = params->nth; - int64_t n_per_thread = (nrows + nth - 1)/nth; int64_t first_row = ith*n_per_thread; if (first_row >= nrows) return; @@ -10607,13 +10655,13 @@ static void ggml_compute_forward_dup( const struct ggml_tensor * src0 = dst->src[0]; - if (src0->type == dst->type) { - ggml_compute_forward_dup_bytes(params, dst); + if (ggml_is_quantized(src0->type)) { + ggml_compute_forward_dup_q(params, dst); return; } - if (ggml_is_quantized(src0->type)) { - ggml_compute_forward_dup_q(params, dst); + if (src0->type == dst->type) { + ggml_compute_forward_dup_bytes(params, dst); return; } From c3b75c531c5d5e295d5c3ba846bb37def72df3a7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 19 Mar 2025 13:03:59 +0100 Subject: [PATCH 045/132] Prevent FlashMLA-1 from running on CUDA (#268) as it is not supported. Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 453bde0c..e488c9bc 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3579,6 +3579,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); } + if (op->src[1]->ne[0] > 256) { + return false; + } if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { return true; } From 22c84a126f50146a851641ccaa6e8a24f0985d79 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 19 Mar 2025 15:47:24 +0100 Subject: [PATCH 046/132] Fix ggml_compute_forward_dup_q (#269) I broke it with PR #265. I was testing with a model where the wk_b and wk_v tensors were present, so didn't need to be computed, so didn't notice that the change I made to ggml_compute_forward_dup_q breaks that computation. Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1552d91b..faf1902d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10576,6 +10576,11 @@ static void ggml_compute_forward_dup_q( if (dst->type == GGML_TYPE_Q8_0 && dst->src[0]->type == GGML_TYPE_Q8_0 && ggml_are_same_shape(dst, dst->src[0])) { + if (dst->src[0]->nb[0] == sizeof(block_q8_0) && dst->nb[0] == sizeof(block_q8_0)) { + ggml_compute_forward_dup_bytes(params, dst); + return; + } + // we assume src is transposed and that's why we are here GGML_ASSERT(dst->ne[0] % QK8_0 == 0); From 127c6ee6493a3084995d754d987f0240ffdffe6a Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 19 Mar 2025 19:17:03 +0100 Subject: [PATCH 047/132] Honor mmap setting when using tensor overrides (#270) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 76039f8e..03139e41 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8015,7 +8015,7 @@ static bool llm_load_tensors( // only the mmap region containing the tensors in the model is mapped to the backend buffer // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size - if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) { + if (ml.use_mmap && use_mmap_buffer && (buft == llama_default_buffer_type_cpu(true) || buft == ggml_backend_cpu_buffer_type())) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) { void * addr = nullptr; size_t first, last; From b8d1fac97b756968b86b470d44bb1026ded7157a Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 21 Mar 2025 07:23:36 +0100 Subject: [PATCH 048/132] Convert models to row-interleaved quants using the quantize tool (#272) * Repack a model with the quantize tool * WIP * Fixed various issues As we don't have a way to tell if a repacked quant has been modified, I had to remove the modification at the expense of a slight decrease in performance. This affects q8_0_r8, q8_KV_r8, q8_k_r8 on Zen4, and q4_0_r8 on ARM. * Create wk_b and wv_b as Q8_0_R8 if the wkv_b type is interleaved * Fix GCC 13.3 compilation error * Another one * Add missing include --------- Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 3 + ggml/src/ggml.c | 48 ++++++++++++--- ggml/src/iqk/iqk_mul_mat.cpp | 71 ++++++++++++--------- ggml/src/iqk/iqk_quantize.cpp | 69 ++++++++++++++++----- ggml/src/iqk/iqk_quantize.h | 3 + include/llama.h | 1 + src/llama.cpp | 109 +++++++++++++++++++++++++++++++-- 7 files changed, 246 insertions(+), 58 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 89de794b..84ea38d4 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -145,6 +145,7 @@ static void usage(const char * executable) { printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor.\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token_embd.weight tensor.\n\n"); printf(" --custom-q regex1=type1,regex2=type2...: use this to specify custom quantization type rules.\n\n"); + printf(" --repack Repack all tensors to the corresponding _r4/8 variant if available.\n\n"); printf("Additional specific tensor quantization types used in the custom quant scheme 'CQS (default is Q2_K):\n"); printf(" --attn-q-type ggml_type: use this ggml_type for the attn_q.weight tensor.\n"); printf(" --attn-k-type ggml_type: use this ggml_type for the attn_k.weight tensor.\n"); @@ -331,6 +332,8 @@ int main(int argc, char ** argv) { params.quantize_output_tensor = false; } else if (strcmp(argv[arg_idx], "--ignore-imatrix-rules") == 0) { params.ignore_imatrix_rules = true; + } else if (strcmp(argv[arg_idx], "--repack") == 0) { + params.only_repack = true; } else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) { if (arg_idx < argc-1) { params.output_tensor_type = parse_ggml_type(argv[++arg_idx]); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index faf1902d..a2bdc156 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1756,6 +1756,15 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { return type_traits[type]; } +static inline int ggml_packed_rows(enum ggml_type type) { + return type == GGML_TYPE_BF16_R16 ? 16 + : type == GGML_TYPE_Q8_K_R8 || type == GGML_TYPE_Q8_KV_R8 || + type == GGML_TYPE_Q8_0_R8 || type == GGML_TYPE_Q4_0_R8 || + type == GGML_TYPE_IQ4_XS_R8 ? 8 + : type >= GGML_TYPE_Q4_0_R8 && type <= GGML_TYPE_Q8_K_R8 ? 4 + : 1; +} + // // simd mappings // @@ -10119,9 +10128,11 @@ static void ggml_compute_forward_dup_f32( } // parallelize by rows + int n_packed = ggml_packed_rows(dst->type); + GGML_ASSERT(dst->ne[1] % n_packed == 0); const int nr = ne01; // number of rows per thread - const int dr = (nr + nth - 1) / nth; + const int dr = n_packed*((nr/n_packed + nth - 1) / nth); // row range for this thread const int ir0 = dr * ith; const int ir1 = MIN(ir0 + dr, nr); @@ -10173,10 +10184,10 @@ static void ggml_compute_forward_dup_f32( for (int i03 = 0; i03 < ne03; i03++) { for (int i02 = 0; i02 < ne02; i02++) { id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { + for (int i01 = ir0; i01 < ir1; i01 += n_packed) { const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); - id += rs; + quantize_row_q(src0_ptr, dst_ptr + id, ne00*n_packed); + id += rs*n_packed; } id += rs * (ne01 - ir1); } @@ -10441,10 +10452,15 @@ static void ggml_compute_forward_dup_bytes( // parallelize by rows const int nr = ne01; + const int n_packed = ggml_packed_rows(dst->type); + GGML_ASSERT(nr%n_packed == 0); + const int nrp = nr/n_packed; // number of rows per thread - const int dr = (nr + nth - 1) / nth; + const int drp = (nrp + nth - 1) / nth; + const int dr = drp*n_packed; // row range for this thread const int ir0 = dr * ith; + if (ir0 >= nr) return; const int ir1 = MIN(ir0 + dr, nr); if (src0->type == dst->type && @@ -10569,10 +10585,19 @@ static void ggml_compute_forward_dup_q( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_quantized(dst->src[0]->type)); + int64_t nrows = ggml_nrows(dst); int ith = params->ith; int nth = params->nth; + if (dst->src[0]->type == dst->type && + dst->src[0]->nb[0] == ggml_type_size(dst->type) && + dst->nb[0] == ggml_type_size(dst->type)) { + ggml_compute_forward_dup_bytes(params, dst); + return; + } + if (dst->type == GGML_TYPE_Q8_0 && dst->src[0]->type == GGML_TYPE_Q8_0 && ggml_are_same_shape(dst, dst->src[0])) { @@ -10626,6 +10651,10 @@ static void ggml_compute_forward_dup_q( return; } + if (dst->type != GGML_TYPE_F32) { + printf("%s: %s -> %s is of type %s\n", __func__, dst->src[0]->name, dst->name, ggml_type_name(dst->type)); + GGML_ABORT("fatal error"); + } GGML_ASSERT(dst->type == GGML_TYPE_F32); struct ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->ne[0] == dst->ne[0] && src0->nb[0] == ggml_type_size(src0->type)); @@ -10633,12 +10662,15 @@ static void ggml_compute_forward_dup_q( ggml_to_float_t to_float = type_traits[src0->type].to_float; GGML_ASSERT(to_float != NULL); - int64_t n_per_thread = (nrows + nth - 1)/nth; + int n_packed = ggml_packed_rows(src0->type); + GGML_ASSERT(src0->ne[1] % n_packed == 0); + + int64_t n_per_thread = n_packed*((nrows/n_packed + nth - 1)/nth); int64_t first_row = ith*n_per_thread; if (first_row >= nrows) return; int64_t last_row = MIN(first_row + n_per_thread, nrows); - for (int64_t ir = first_row; ir < last_row; ++ir) { + for (int64_t ir = first_row; ir < last_row; ir += n_packed) { int64_t i03 = ir/(src0->ne[1]*src0->ne[2]); int64_t i02 = (ir - i03*src0->ne[1]*src0->ne[2])/src0->ne[1]; int64_t i01 = ir - i03*src0->ne[1]*src0->ne[2] - i02*src0->ne[1]; @@ -10649,7 +10681,7 @@ static void ggml_compute_forward_dup_q( const char * q = (const char *)src0->data + i03*src0->nb[3] + i02*src0->nb[2] + i01*src0->nb[1]; char * f = ( char *)dst->data + i3* dst->nb[3] + i2* dst->nb[2] + i1* dst->nb[1]; - to_float((const void *)q, (float *)f, src0->ne[0]); + to_float((const void *)q, (float *)f, src0->ne[0]*n_packed); } } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 14cc64db..8b6d6b1c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4434,14 +4434,17 @@ inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) { return sumi; } inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) { - qx[0] = _mm256_loadu_si256((const __m256i *)x+0); - qx[1] = _mm256_loadu_si256((const __m256i *)x+1); - qx[2] = _mm256_loadu_si256((const __m256i *)x+2); - qx[3] = _mm256_loadu_si256((const __m256i *)x+3); - qx[4] = _mm256_loadu_si256((const __m256i *)x+4); - qx[5] = _mm256_loadu_si256((const __m256i *)x+5); - qx[6] = _mm256_loadu_si256((const __m256i *)x+6); - qx[7] = _mm256_loadu_si256((const __m256i *)x+7); + for (int i = 0; i < 8; ++i) { + qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127)); + } + //qx[0] = _mm256_loadu_si256((const __m256i *)x+0); + //qx[1] = _mm256_loadu_si256((const __m256i *)x+1); + //qx[2] = _mm256_loadu_si256((const __m256i *)x+2); + //qx[3] = _mm256_loadu_si256((const __m256i *)x+3); + //qx[4] = _mm256_loadu_si256((const __m256i *)x+4); + //qx[5] = _mm256_loadu_si256((const __m256i *)x+5); + //qx[6] = _mm256_loadu_si256((const __m256i *)x+6); + //qx[7] = _mm256_loadu_si256((const __m256i *)x+7); return qx_r8_q8_dot_product(qx, y); } template @@ -4496,6 +4499,7 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int j = 0; j < 8; ++j) { qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)), _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); } for (int iy = 0; iy < nrc_y; ++iy) { auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k); @@ -4512,6 +4516,7 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int j = 0; j < 8; ++j) { qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)), _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); } for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; @@ -6347,6 +6352,11 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn auto s1 = _mm256_sign_epi8(qx[1], qx[1]); auto s2 = _mm256_sign_epi8(qx[2], qx[2]); auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#else + qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127)); #endif for (int iy = 0; iy < nrc_y; ++iy) { auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib); @@ -6425,6 +6435,11 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data auto s1 = _mm256_sign_epi8(qx[1], qx[1]); auto s2 = _mm256_sign_epi8(qx[2], qx[2]); auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#else + qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127)); + qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127)); + qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127)); + qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127)); #endif for (int iy = 0; iy < nrc_y; ++iy) { auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib); @@ -14305,8 +14320,8 @@ struct Q4_0_R8_Dequantizer { float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) }; for (int j = 0; j < 4; ++j) { auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j); - //bits.val[0] = veorq_u8(m88, bits.val[0]); - //bits.val[1] = veorq_u8(m88, bits.val[1]); + bits.val[0] = veorq_u8(m88, bits.val[0]); + bits.val[1] = veorq_u8(m88, bits.val[1]); qx[2*j+0] = vshlq_n_u8(bits.val[0], 4); qx[2*j+1] = vandq_u8(bits.val[0], m4); qx[2*j+8] = vshlq_n_u8(bits.val[1], 4); @@ -15305,12 +15320,12 @@ struct HelperQ80R8 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); -#ifdef HAVE_FANCY_SIMD - m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); - m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); - m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); - m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); -#endif +//#ifdef HAVE_FANCY_SIMD +// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); +// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); +// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); +// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +//#endif _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); @@ -15327,12 +15342,12 @@ struct HelperQ80R8 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); -#ifdef HAVE_FANCY_SIMD - m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); - m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); - m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); - m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); -#endif +//#ifdef HAVE_FANCY_SIMD +// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); +// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); +// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); +// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +//#endif _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); @@ -15424,12 +15439,12 @@ struct HelperQ8KVR8 : public BaseHelper { m1 = _mm256_unpackhi_epi64(t0, t1); m2 = _mm256_unpacklo_epi64(t2, t3); m3 = _mm256_unpackhi_epi64(t2, t3); -#ifdef HAVE_FANCY_SIMD - m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); - m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); - m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); - m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); -#endif +//#ifdef HAVE_FANCY_SIMD +// m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); +// m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); +// m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); +// m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); +//#endif _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+0, m0); _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+1, m1); _mm256_storeu_si256((__m256i *)y[ix].qs + 4*ib+2, m2); diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index fb6a5db4..5e657f4a 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -6766,9 +6767,7 @@ struct Modify { modify_func_t mod_func; int nrows; }; -} - -bool iqk_modify_tensor(struct ggml_tensor * tensor) { +const Modify * get_modify_info(ggml_type type) { static const std::unordered_map k_mod_map = { #ifdef __ARM_NEON { GGML_TYPE_Q4_0_R8, {modify_q4_0_r8, 8} }, @@ -6779,10 +6778,31 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_Q8_KV_R8, {modify_q8_KV_r8, 8} }, #endif }; - auto it = k_mod_map.find(tensor->type); - if (it == k_mod_map.end()) return false; + auto it = k_mod_map.find(type); + return it != k_mod_map.end() ? &it->second : nullptr; +} +bool is_forbidden_tensor(const std::string& name) { + static const std::string kTokenEmbd{"token_embd.weight"}; + if (name == kTokenEmbd) return true; + //if (auto pos = name.find("attn_kv_b.weight"); pos != std::string::npos) return true; + return false; +} +} - auto& m = it->second; +bool iqk_should_modify_tensor([[maybe_unused]] const struct ggml_tensor * tensor) { + return false; + //if (is_forbidden_tensor(tensor->name)) return false; + //auto mptr = get_modify_info(tensor->type); + //return mptr ? true : false; +} + +bool iqk_modify_tensor(struct ggml_tensor * tensor) { + return false; + auto mptr = get_modify_info(tensor->type); + if (!mptr) return false; + if (is_forbidden_tensor(std::string{tensor->name})) return false; + + auto& m = *mptr; int nrows = ggml_nrows(tensor); int nchunks = nrows/m.nrows; int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); @@ -6805,12 +6825,8 @@ bool iqk_modify_tensor(struct ggml_tensor * tensor) { return true; } -void iqk_repack_tensor(struct ggml_tensor * tensor) { - constexpr int kChunk = 8; - if (!tensor) return; - if (!ggml_is_contiguous(tensor)) return; - if (strncmp(tensor->name, "token_embd.weight", GGML_MAX_NAME) == 0) return; - if (tensor->ne[1] % 4) return; +namespace { +const Repack * get_repack_info(ggml_type type) { static const std::unordered_map k_map = { { GGML_TYPE_IQ2_K, { GGML_TYPE_IQ2_K_R4, 4, (Repack::repack_func)repack_iq2_k} }, { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, @@ -6841,12 +6857,30 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { { GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16} }, #endif }; + auto it = k_map.find(type); + return it != k_map.end() ? &it->second : nullptr; +} +} - auto it = k_map.find(tensor->type); - if (it == k_map.end()) return; - if (tensor->ne[1] % it->second.num_rows) return; +int iqk_repacked_type(const struct ggml_tensor * tensor) { + if (!ggml_is_contiguous(tensor)) return (int)tensor->type; + if (is_forbidden_tensor(tensor->name)) return (int)tensor->type; + auto rptr = get_repack_info(tensor->type); + return rptr && tensor->ne[1] % rptr->num_rows == 0 ? (int)rptr->new_type : (int)tensor->type; +} - auto& r = it->second; +void iqk_repack_tensor(struct ggml_tensor * tensor) { + constexpr int kChunk = 8; + if (!tensor) return; + if (!ggml_is_contiguous(tensor)) return; + if (is_forbidden_tensor(tensor->name)) return; + if (tensor->ne[1] % 4) return; + + auto rptr = get_repack_info(tensor->type); + if (!rptr) return; + if (tensor->ne[1] % rptr->num_rows) return; + + auto& r = *rptr; auto nrows = ggml_nrows(tensor); @@ -6871,7 +6905,8 @@ void iqk_repack_tensor(struct ggml_tensor * tensor) { int last_row = std::min(first_row + chunkSize*r.num_rows, nrows); for (int row = first_row; row < last_row; row += r.num_rows) { std::memcpy(qtmp.data(), data + row*row_size, r.num_rows*row_size); - r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, true); + //r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, true); + r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size, false); } } }; diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index d447705b..dd148f2e 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -245,6 +245,9 @@ void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT d void iqk_repack_tensor(struct ggml_tensor * tensor); bool iqk_modify_tensor(struct ggml_tensor * tensor); +int iqk_repacked_type(const struct ggml_tensor * tensor); // int instead of ggml_type so we don't need to include ggml.h +bool iqk_should_modify_tensor(const struct ggml_tensor * tensor); + // So we can re-pack Microsoft's BitNet I2_S quants void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); diff --git a/include/llama.h b/include/llama.h index 5e86cb68..66e9af02 100644 --- a/include/llama.h +++ b/include/llama.h @@ -416,6 +416,7 @@ extern "C" { bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards bool ignore_imatrix_rules; // If set to true, the built-in rules for refusing to quantize into certain quants without imatrix are ignored + bool only_repack; // Only repack tensors void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * custom_quants; // pointer to vector containing custom quantization rules diff --git a/src/llama.cpp b/src/llama.cpp index 03139e41..a459cb00 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8194,7 +8194,8 @@ static bool llm_load_tensors( auto wk_b_f32_t = ggml_cont(ctx, wk_b_f32_tview); wk_b_f32_t->data = (char *)wk_b_f32->data + ggml_nbytes(wk_b_f32); - auto new_type = ggml_is_quantized(wkv_b.type) ? GGML_TYPE_Q8_0 : wkv_b.type; + auto new_type = ggml_is_quantized(wkv_b.type) ? + wkv_b.type >= GGML_TYPE_Q4_0_R8 && wkv_b.type <= GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_0_R8 : GGML_TYPE_Q8_0 : wkv_b.type; auto wk_b = ggml_cast(ctx, wk_b_f32_t, new_type); wk_b->data = (char *)wk_b_f32_t->data + ggml_nbytes(wk_b_f32_t); @@ -8218,6 +8219,9 @@ static bool llm_load_tensors( ggml_set_name(l.computed_wk_b.get(), name.c_str()); ggml_backend_buffer_set_usage(l.computed_wk_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); ggml_backend_tensor_set(l.computed_wk_b.get(), wk_b->data, 0, ggml_nbytes(wk_b)); + if (ggml_backend_buffer_is_host(l.computed_wk_b->buffer)) { + iqk_modify_tensor(l.computed_wk_b.get()); + } l.wk_b = l.computed_wk_b.get(); @@ -8243,6 +8247,9 @@ static bool llm_load_tensors( ggml_set_name(l.computed_wv_b.get(), name.c_str()); ggml_backend_buffer_set_usage(l.computed_wv_b->buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); ggml_backend_tensor_set(l.computed_wv_b.get(), wv_b->data, 0, ggml_nbytes(wv_b)); + if (ggml_backend_buffer_is_host(l.computed_wv_b->buffer)) { + iqk_modify_tensor(l.computed_wv_b.get()); + } l.wv_b = l.computed_wv_b.get(); @@ -17140,11 +17147,48 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa return new_size; } +static llama_ftype repacked_ftype(llama_ftype ftype) { + static std::unordered_map k_map = { + { LLAMA_FTYPE_MOSTLY_Q4_0, LLAMA_FTYPE_MOSTLY_Q4_0_R8 }, + { LLAMA_FTYPE_MOSTLY_Q8_0, LLAMA_FTYPE_MOSTLY_Q8_0_R8 }, + { LLAMA_FTYPE_MOSTLY_Q5_0, LLAMA_FTYPE_MOSTLY_Q5_0_R4 }, + { LLAMA_FTYPE_MOSTLY_Q2_K, LLAMA_FTYPE_MOSTLY_Q2_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q3_K_S, LLAMA_FTYPE_MOSTLY_Q3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q3_K_M, LLAMA_FTYPE_MOSTLY_Q3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q3_K_L, LLAMA_FTYPE_MOSTLY_Q3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q4_K_S, LLAMA_FTYPE_MOSTLY_Q4_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q4_K_M, LLAMA_FTYPE_MOSTLY_Q4_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q5_K_S, LLAMA_FTYPE_MOSTLY_Q5_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q5_K_M, LLAMA_FTYPE_MOSTLY_Q5_K_R4 }, + { LLAMA_FTYPE_MOSTLY_Q6_K, LLAMA_FTYPE_MOSTLY_Q6_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_XXS, LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_XS, LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ3_XXS, LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ1_S, LLAMA_FTYPE_MOSTLY_IQ1_S_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_NL, LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ3_S, LLAMA_FTYPE_MOSTLY_IQ3_S_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_M, LLAMA_FTYPE_MOSTLY_IQ2_M_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_XS, LLAMA_FTYPE_MOSTLY_IQ4_XS_R8 }, + { LLAMA_FTYPE_MOSTLY_IQ1_M, LLAMA_FTYPE_MOSTLY_IQ1_M_R4 }, + { LLAMA_FTYPE_MOSTLY_Q6_0, LLAMA_FTYPE_MOSTLY_Q6_0_R4 }, + { LLAMA_FTYPE_MOSTLY_BF16, LLAMA_FTYPE_MOSTLY_BF16_R16 }, + { LLAMA_FTYPE_MOSTLY_IQ2_BN, LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ2_K, LLAMA_FTYPE_MOSTLY_IQ2_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ3_K, LLAMA_FTYPE_MOSTLY_IQ3_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_K, LLAMA_FTYPE_MOSTLY_IQ4_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ5_K, LLAMA_FTYPE_MOSTLY_IQ5_K_R4 }, + { LLAMA_FTYPE_MOSTLY_IQ4_KS, LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 }, + { LLAMA_FTYPE_MOSTLY_Q8_KV, LLAMA_FTYPE_MOSTLY_Q8_KV_R8 }, + }; + if (auto it = k_map.find(ftype); it != k_map.end()) return it->second; + return ftype; +} + static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; - switch (params->ftype) { + switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; @@ -17256,7 +17300,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s ftype = model.ftype; } const std::unordered_map> * imatrix_data = nullptr; - if (params->imatrix) { + if (!params->only_repack && params->imatrix) { imatrix_data = static_cast>*>(params->imatrix); if (imatrix_data) { LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); @@ -17278,7 +17322,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // copy the KV pairs from the input file gguf_set_kv (ctx_out, ml.meta); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV - gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV // Remove split metadata gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); @@ -17303,9 +17346,20 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } + bool is_repacked = ml.ftype >= LLAMA_FTYPE_MOSTLY_Q4_0_R8 && ml.ftype <= LLAMA_FTYPE_MOSTLY_Q8_K_R8; + int n_to_repack = 0, n_to_modify = 0; for (int i = 0; i < ml.n_tensors; ++i) { const struct ggml_tensor * meta = ml.get_tensor_meta(i); + if (params->only_repack) { + auto repacked_type = (ggml_type)iqk_repacked_type(meta); + if (repacked_type != meta->type) { + ++n_to_repack; + } else if (!is_repacked) { + if (iqk_should_modify_tensor(meta)) ++n_to_modify; + } + } + const std::string name = ggml_get_name(meta); // TODO: avoid hardcoded tensor names - use the TN_* constants @@ -17317,6 +17371,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } + if (params->only_repack) { + if (n_to_repack == 0 && n_to_modify == 0) { + printf("=========================== %s: nothing to do for only_repack option\n", __func__); + return; + } + ftype = repacked_ftype(model.ftype); + printf("===================== Model ftype: %s: Repacked ftype: %s\n", llama_model_ftype_name(model.ftype).c_str(), + llama_model_ftype_name(ftype).c_str()); + } + + gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; // sanity checks @@ -17457,6 +17523,36 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s void * new_data; size_t new_size; + if (params->only_repack) { + ggml_type repacked_type = (ggml_type)iqk_repacked_type(tensor); + bool modify = !is_repacked && iqk_should_modify_tensor(tensor); + if (modify || repacked_type != tensor->type) { + new_type = repacked_type; + new_size = ggml_nbytes(tensor); + if ((int)work.size() < new_size) work.resize(new_size); + new_data = work.data(); + + auto aux_tensor = *tensor; + aux_tensor.data = work.data(); + std::memcpy(aux_tensor.data, tensor->data, new_size); + + if (repacked_type != tensor->type) { + iqk_repack_tensor(&aux_tensor); + GGML_ASSERT(aux_tensor.type == repacked_type); + } else { + bool did_modify = iqk_modify_tensor(&aux_tensor); + GGML_ASSERT(did_modify); + } + } + else { + new_type = tensor->type; + new_size = ggml_nbytes(tensor); + new_data = tensor->data; + } + LLAMA_LOG_INFO("size = %8.3f MB, type = %s\n", new_size/1024.0/1024.0, ggml_type_name(new_type)); + goto QuantizationDone; + } + if (quantize) { new_type = default_type; if (new_type == GGML_TYPE_BF16_R16 && strcmp(tensor->name, "token_embd.weight") == 0) { @@ -17562,7 +17658,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type == GGML_TYPE_IQ1_S_R4|| new_type == GGML_TYPE_IQ1_M_R4|| (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { + (new_type == GGML_TYPE_Q2_K && ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { LLAMA_LOG_ERROR("\n\n============================================================\n"); LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); @@ -17727,6 +17823,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); } + +QuantizationDone:; total_size_org += ggml_nbytes(tensor); total_size_new += new_size; @@ -18051,6 +18149,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.pure =*/ false, /*.keep_split =*/ false, /*.ignore_imatrix_rules =*/ false, + /*.only_repack =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.custom_quants =*/ nullptr, From ddc8eee10ee9216de57429167e6f74e618577d93 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 21 Mar 2025 07:24:22 +0100 Subject: [PATCH 049/132] FlashMLA-3: the best of both worlds (CPU only) (#273) * Repack a model with the quantize tool * WIP * Fixed various issues As we don't have a way to tell if a repacked quant has been modified, I had to remove the modification at the expense of a slight decrease in performance. This affects q8_0_r8, q8_KV_r8, q8_k_r8 on Zen4, and q4_0_r8 on ARM. * Create wk_b and wv_b as Q8_0_R8 if the wkv_b type is interleaved * Fix GCC 13.3 compilation error * Another one * Add missing include * FlashMLA-3: the best of both worlds - CPU only --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index a459cb00..33b69389 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13760,7 +13760,7 @@ struct llm_build_context { ggml_tensor * kqv; - if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && (pp_opt || lctx.cparams.mla_attn > 2)) { + if (lctx.cparams.mla_attn > 1 && lctx.cparams.flash_attn && pp_opt) { // PP for mla=2,3 auto kv_cache_nope = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, kv_self.kv_l[il]->nb[1], 0); @@ -13869,7 +13869,7 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1) { + if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); From 022660f7aba973c149e011eac5c4b3dfea02618d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 21 Mar 2025 10:51:37 +0100 Subject: [PATCH 050/132] Specify tensor name regex for tensors to be repacked (#274) Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 14 ++++++++++++ include/llama.h | 1 + src/llama.cpp | 41 ++++++++++++++++++++++++++++++---- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 84ea38d4..1d00c874 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -146,6 +146,7 @@ static void usage(const char * executable) { printf(" --token-embedding-type ggml_type: use this ggml_type for the token_embd.weight tensor.\n\n"); printf(" --custom-q regex1=type1,regex2=type2...: use this to specify custom quantization type rules.\n\n"); printf(" --repack Repack all tensors to the corresponding _r4/8 variant if available.\n\n"); + printf(" --repack-pattern Comma separated list of regexs to use for matching tensor names to be repacked.\n\n"); printf("Additional specific tensor quantization types used in the custom quant scheme 'CQS (default is Q2_K):\n"); printf(" --attn-q-type ggml_type: use this ggml_type for the attn_q.weight tensor.\n"); printf(" --attn-k-type ggml_type: use this ggml_type for the attn_k.weight tensor.\n"); @@ -327,6 +328,8 @@ int main(int argc, char ** argv) { std::vector kv_overrides; std::vector custom_quants; + std::vector repack_patterns; + for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { params.quantize_output_tensor = false; @@ -334,6 +337,13 @@ int main(int argc, char ** argv) { params.ignore_imatrix_rules = true; } else if (strcmp(argv[arg_idx], "--repack") == 0) { params.only_repack = true; + } else if (strcmp(argv[arg_idx], "--repack-pattern") == 0) { + if (arg_idx < argc-1) { + auto p = string_split(argv[++arg_idx], ','); + repack_patterns.insert(repack_patterns.end(), p.begin(), p.end()); + } else { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) { if (arg_idx < argc-1) { params.output_tensor_type = parse_ggml_type(argv[++arg_idx]); @@ -431,6 +441,10 @@ int main(int argc, char ** argv) { } } + if (!repack_patterns.empty()) { + params.repack_pattern = &repack_patterns; + } + if (argc - arg_idx < 2) { printf("%s: bad arguments\n", argv[0]); usage(argv[0]); diff --git a/include/llama.h b/include/llama.h index 66e9af02..1b9c47e9 100644 --- a/include/llama.h +++ b/include/llama.h @@ -420,6 +420,7 @@ extern "C" { void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * custom_quants; // pointer to vector containing custom quantization rules + void * repack_pattern; // pointer to a vector containing regexes to be used for matching tensor names. Can be null } llama_model_quantize_params; // grammar types diff --git a/src/llama.cpp b/src/llama.cpp index 33b69389..dfe445b8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17348,20 +17348,39 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s bool is_repacked = ml.ftype >= LLAMA_FTYPE_MOSTLY_Q4_0_R8 && ml.ftype <= LLAMA_FTYPE_MOSTLY_Q8_K_R8; int n_to_repack = 0, n_to_modify = 0; + const std::vector * repack_pattern = nullptr; + if (params->repack_pattern) repack_pattern = (const std::vector *)params->repack_pattern; + for (int i = 0; i < ml.n_tensors; ++i) { const struct ggml_tensor * meta = ml.get_tensor_meta(i); + const std::string name = ggml_get_name(meta); + if (params->only_repack) { auto repacked_type = (ggml_type)iqk_repacked_type(meta); + bool repack = false, modify = false; if (repacked_type != meta->type) { - ++n_to_repack; + repack = true; } else if (!is_repacked) { - if (iqk_should_modify_tensor(meta)) ++n_to_modify; + if (iqk_should_modify_tensor(meta)) { + modify = true; + } } + if ((repack || modify) && repack_pattern) { + bool found = false; + for (auto& r : *repack_pattern) { + std::regex pattern(r); + if (std::regex_search(name, pattern)) { + found = true; + break; + } + } + if (!found) repack = modify = false; + } + if (repack) ++n_to_repack; + else if (modify) ++n_to_modify; } - const std::string name = ggml_get_name(meta); - // TODO: avoid hardcoded tensor names - use the TN_* constants if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { @@ -17526,6 +17545,19 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (params->only_repack) { ggml_type repacked_type = (ggml_type)iqk_repacked_type(tensor); bool modify = !is_repacked && iqk_should_modify_tensor(tensor); + if ((modify || repacked_type != tensor->type) && repack_pattern) { + bool found = false; + for (auto& r : *repack_pattern) { + std::regex pattern(r); + if (std::regex_search(tensor->name, pattern)) { + found = true; break; + } + } + if (!found) { + modify = false; + repacked_type = tensor->type; + } + } if (modify || repacked_type != tensor->type) { new_type = repacked_type; new_size = ggml_nbytes(tensor); @@ -18153,6 +18185,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.custom_quants =*/ nullptr, + /*.repack_pattern =*/ nullptr, }; return result; From 3d6e25c82db5510df483185b8a20f0ce01136dd7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 21 Mar 2025 13:23:01 +0100 Subject: [PATCH 051/132] Fix bug: missing parentheses in logical expression (#275) This results in GGGGGGGGGGGGG when generating with mla = 3, fa = 0. Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index dfe445b8..186cb5a5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13869,7 +13869,7 @@ struct llm_build_context { ggml_tensor * q = ggml_concat(ctx0, q_nope2, ggml_permute(ctx0, q_rope, 0, 2, 1, 3), 0); cb(q, "q", il); - if (lctx.cparams.flash_attn && lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3) { + if (lctx.cparams.flash_attn && (lctx.cparams.mla_attn == 1 || lctx.cparams.mla_attn == 3)) { ggml_tensor * kv_cache_lora = ggml_view_2d(ctx0, kv_self.kv_l[il], kv_lora_rank, n_kv, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0); From d8584a1bbe561f681d55830688eef33f43155af6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 22 Mar 2025 08:05:10 +0100 Subject: [PATCH 052/132] Add Gemma3 support (text only) (#276) * WIP Gemma3: not working * gemma3: build_gemma3 seems to be working now * Revert changes to convert_hf_to_gguf.py It wasn't working, so I guess, it is better to leave the conversion up tp upstream. --------- Co-authored-by: Iwan Kawrakow --- gguf-py/gguf/constants.py | 2 + src/llama.cpp | 208 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 037837da..93c614d6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -207,6 +207,7 @@ class MODEL_ARCH(IntEnum): MINICPM = auto() GEMMA = auto() GEMMA2 = auto() + GEMMA3 = auto() STARCODER2 = auto() MAMBA = auto() XVERSE = auto() @@ -338,6 +339,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MINICPM: "minicpm", MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", diff --git a/src/llama.cpp b/src/llama.cpp index 186cb5a5..54e0bb01 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -203,6 +203,7 @@ enum llm_arch { LLM_ARCH_MINICPM, LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, @@ -250,6 +251,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM, "minicpm" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -1056,6 +1058,26 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -2313,6 +2335,7 @@ struct llama_hparams { uint32_t n_layer; uint32_t n_rot; uint32_t n_swa = 0; // sliding window attention (SWA) + uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; @@ -2342,7 +2365,9 @@ struct llama_hparams { float rope_attn_factor = 1.0f; float rope_freq_base_train; + float rope_freq_base_train_swa; float rope_freq_scale_train; + float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul; @@ -4963,6 +4988,10 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // non-transformer models do not have attention heads @@ -5310,6 +5339,28 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GEMMA3: + { + hparams.n_swa_pattern = 6; + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_2B; break; + case 34: model.type = e_model::MODEL_4B; break; + case 48: model.type = e_model::MODEL_12B; break; + case 62: model.type = e_model::MODEL_27B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + + hparams.f_attention_scale = model.type == e_model::MODEL_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -7411,6 +7462,38 @@ static bool llm_load_tensors( layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}); } } break; + case LLM_ARCH_GEMMA3: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_STARCODER2: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -12664,6 +12747,126 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_gemma3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head_k = hparams.n_embd_head_k; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (batch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + // gemma3 requires different mask for layers using sliding window (SWA) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true); + struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true); + + // "5-to-1 interleaved attention" + // 5 layers of local attention followed by 1 layer of global attention + static const int sliding_window_pattern = 6; + + for (int il = 0; il < n_layer; ++il) { + const bool is_sliding = (il + 1) % sliding_window_pattern; + const float freq_base_l = is_sliding ? 10000.0f : freq_base; + const float freq_scale_l = is_sliding ? 1.0f : freq_scale; + struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il); + } + + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_post_norm", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } struct ggml_cgraph * build_starcoder2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -14996,6 +15199,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_gemma2(); } break; + case LLM_ARCH_GEMMA3: + { + result = llm.build_gemma3(); + } break; case LLM_ARCH_STARCODER2: { result = llm.build_starcoder2(); @@ -18826,6 +19033,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: + case LLM_ARCH_GEMMA3: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: From 13ecc5332ec4c5ad38d930019a1b6a7211c88e50 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 22 Mar 2025 16:58:30 +0100 Subject: [PATCH 053/132] Fighting with cmake (#279) Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 7013526c..67b8f36d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -421,6 +421,11 @@ if (GGML_CUDA) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ... endif() endif() + if (NOT GGML_MUSA) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_LIBRARIES 0) + set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_OBJECTS 0) + endif() else() message(WARNING "CUDA not found") endif() From 6028362ef6112c035001fef4313ab3697ef79a30 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 22 Mar 2025 18:17:51 +0100 Subject: [PATCH 054/132] Native build ooption for CUDA when GGML_NATIVE is set (#280) Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 67b8f36d..f7f15fbd 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -297,10 +297,12 @@ if (GGML_CUDA) # 60 == FP16 CUDA intrinsics # 61 == integer CUDA intrinsics # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") + if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + set(CMAKE_CUDA_ARCHITECTURES "native") + elseif (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80") else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80") #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() From dd5ebd0e3d42e871ac398deaaa3f60edaa78a7eb Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 23 Mar 2025 07:24:43 +0100 Subject: [PATCH 055/132] Test transparent huge pages on Linux (#278) * Adding ability to use THP on Linux * Use the actual page size4 used for mmap also in munmap * Add -thp to llama-bench --------- Co-authored-by: Iwan Kawrakow --- common/common.cpp | 6 +++ common/common.h | 1 + examples/llama-bench/llama-bench.cpp | 34 ++++++++++++-- include/llama.h | 1 + src/llama.cpp | 70 ++++++++++++++++++++++++---- 5 files changed, 99 insertions(+), 13 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index e62944b9..5b3d60b3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -993,6 +993,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mmap = false; return true; } + if (arg == "-thp" || arg == "--transparent-huge-pages") { + params.use_thp = true; + return true; + } if (arg == "--numa") { CHECK_ARG std::string value(argv[i]); @@ -2316,6 +2320,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.repack_tensors = params.repack_tensors; + mparams.use_thp = params.use_thp; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -3371,6 +3376,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false"); + fprintf(stream, "use_thp: %s # default: false\n", params.use_thp ? "true" : "false"); fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); diff --git a/common/common.h b/common/common.h index f6a55885..79c5fe96 100644 --- a/common/common.h +++ b/common/common.h @@ -194,6 +194,7 @@ struct gpt_params { bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data bool repack_tensors = false; // repack tensors if interleaved variant is available + bool use_thp = false; // use transparent huge pages (linux only) std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 167525bc..24d2e185 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -248,6 +248,7 @@ struct cmd_params { bool warmup; bool repack = false; bool fmoe = false; + bool use_thp = false; output_formats output_format; output_formats output_format_stderr; }; @@ -281,6 +282,7 @@ static const cmd_params cmd_params_defaults = { /* verbose */ false, /* warmup */ true, /* repack */ false, + /* use_thp */ false, /* fmoe */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, @@ -320,6 +322,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); + printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0"); printf(" -ot, --override-tensor pattern (default: none)\n"); printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); printf("\n"); @@ -691,6 +694,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.repack = std::stoi(argv[i]); + } else if (arg == "-thp" || arg == "--transparent-huge-pages") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.use_thp = std::stoi(argv[i]); } else if (arg == "-fmoe" || arg == "--fused-moe") { if (++i >= argc) { invalid_param = true; @@ -781,6 +790,7 @@ struct cmd_params_instance { bool embeddings; bool repack = false; bool fmoe = false; + bool use_thp = false; const llama_model_tensor_buft_override* buft_overrides; llama_model_params to_llama_mparams() const { @@ -795,6 +805,7 @@ struct cmd_params_instance { mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; mparams.repack_tensors = repack; + mparams.use_thp = use_thp; mparams.tensor_buft_overrides = buft_overrides; return mparams; @@ -808,6 +819,7 @@ struct cmd_params_instance { main_gpu == other.main_gpu && use_mmap == other.use_mmap && repack == other.repack && + use_thp == other.use_thp && tensor_split == other.tensor_split; } @@ -882,6 +894,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -915,6 +928,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -948,6 +962,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -981,6 +996,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .embeddings = */ embd, /* .repack = */ params.repack, /* .fmoe = */ params.fmoe, + /* .use_thp = */ params.use_thp, /* .buft_overrides=*/ params.buft_overrides.data(), }; instances.push_back(instance); @@ -1025,6 +1041,7 @@ struct test { bool embeddings; bool repack = false; bool fmoe = false; + bool use_thp = false; int n_prompt; int n_gen; std::string test_time; @@ -1058,6 +1075,7 @@ struct test { embeddings = inst.embeddings; repack = inst.repack; fmoe = inst.fmoe; + use_thp = inst.use_thp; n_prompt = inst.n_prompt; n_gen = inst.n_gen; test_kind = inst.test_kind; @@ -1148,7 +1166,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", - "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", + "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "use_thp", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1169,7 +1187,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || field == "fused_moe") { return BOOL; } @@ -1211,7 +1229,8 @@ struct test { std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), + std::to_string(repack), std::to_string(fmoe), std::to_string(use_thp), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()), @@ -1389,6 +1408,9 @@ struct markdown_printer : public printer { if (field == "repack") { return 3; } + if (field == "use_thp") { + return 3; + } if (field == "fused_moe") { return 4; } @@ -1435,6 +1457,9 @@ struct markdown_printer : public printer { if (field == "repack") { return "rtr"; } + if (field == "use_thp") { + return "thp"; + } if (field == "fused_moe") { return "fmoe"; } @@ -1505,6 +1530,9 @@ struct markdown_printer : public printer { if (params.repack != cmd_params_defaults.repack) { fields.emplace_back("repack"); } + if (params.use_thp != cmd_params_defaults.use_thp) { + fields.emplace_back("use_thp"); + } if (params.fmoe != cmd_params_defaults.fmoe) { fields.emplace_back("fused_moe"); } diff --git a/include/llama.h b/include/llama.h index 1b9c47e9..38822dbe 100644 --- a/include/llama.h +++ b/include/llama.h @@ -345,6 +345,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool repack_tensors;// repack if available + bool use_thp; // uase transparent huge pages (linux only) }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/src/llama.cpp b/src/llama.cpp index 54e0bb01..395b2879 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1827,6 +1827,7 @@ using llama_files = std::vector>; struct llama_mmap { void * addr; size_t size; + size_t mapped_page_size = 0; llama_mmap(const llama_mmap &) = delete; @@ -1836,7 +1837,7 @@ struct llama_mmap { // list of mapped fragments (first_offset, last_offset) std::vector> mapped_fragments; - llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false, [[maybe_unused]] bool use_thp = false) { size = file->size; int fd = fileno(file->fp); int flags = MAP_SHARED; @@ -1849,6 +1850,29 @@ struct llama_mmap { strerror(errno)); } if (prefetch) { flags |= MAP_POPULATE; } + if (use_thp) { + size_t huge = get_default_huge_page_size(); + auto size = huge*((file->size + huge - 1)/huge); + addr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB, -1, 0); + if (addr != MAP_FAILED) { + printf("%s: using THP with page size %zu MiB ", __func__, huge/(1024*1024)); + fflush(stdout); + size_t tot = 0; + while (tot < file->size) { + auto n_read = pread(fd, static_cast(addr) + tot, file->size - tot, tot); + if (n_read < 0) throw std::runtime_error(format("Reading into mapped huge pages failed at %zu (%s)", tot, strerror(errno))); + printf("."); fflush(stdout); + tot += n_read; + } + printf(" done\n"); + mapped_fragments.emplace_back(0, file->size); + mapped_page_size = huge; + return; + } + else { + fprintf(stderr, "%s: mmap with huge page size %zu MiB failed (%s)\n", __func__, huge/(1024*1024), strerror(errno)); + } + } #endif addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); if (addr == MAP_FAILED) { // NOLINT @@ -1893,7 +1917,7 @@ struct llama_mmap { void unmap_fragment(size_t first, size_t last) { // note: this function must not be called multiple times with overlapping ranges // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings - int page_size = sysconf(_SC_PAGESIZE); + int page_size = mapped_page_size > 0 ? mapped_page_size : sysconf(_SC_PAGESIZE); align_range(&first, &last, page_size); size_t len = last - first; @@ -1935,6 +1959,28 @@ struct llama_mmap { mapped_fragments = std::move(new_mapped_fragments); } +#ifdef __linux__ + static int get_default_huge_page_size() { + int pg_size = 2048; + std::ifstream in("/proc/meminfo"); + if (in) { + std::string line; + while (true) { + std::getline(in, line); + if (in.fail()) break; + if (auto pos = line.find("Hugepagesize:"); pos != std::string::npos) { + std::istringstream str(line.data() + pos + 13); + int aux; + str >> aux; + if (!str.fail()) pg_size = aux; + break; + } + } + } + return pg_size * 1024; + } +#endif + ~llama_mmap() { for (const auto & frag : mapped_fragments) { if (munmap((char *) addr + frag.first, frag.second - frag.first)) { @@ -1945,7 +1991,7 @@ struct llama_mmap { #elif defined(_WIN32) static constexpr bool SUPPORTED = true; - llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false) { + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false, [[maybe_unused]] bool use_thp = false) { GGML_UNUSED(numa); size = file->size; @@ -2007,10 +2053,11 @@ struct llama_mmap { #else static constexpr bool SUPPORTED = false; - llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false) { + llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false, bool use_thp = false) { GGML_UNUSED(file); GGML_UNUSED(prefetch); GGML_UNUSED(numa); + GGML_UNUSED(use_thp); throw std::runtime_error("mmap not supported"); } @@ -3842,6 +3889,7 @@ struct llama_model_loader { bool use_mmap = false; bool check_tensors; bool repack_tensors = false; + bool use_thp = false; llama_files files; llama_ftype ftype; @@ -3876,7 +3924,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, const llama_model_kv_override * param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; @@ -4140,6 +4188,7 @@ struct llama_model_loader { this->use_mmap = use_mmap; this->check_tensors = check_tensors; this->repack_tensors = repack_tensors; + this->use_thp = use_thp; } ~llama_model_loader() { @@ -4453,12 +4502,12 @@ struct llama_model_loader { } } - void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr) { + void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr, bool use_thp = false) { if (use_mmap) { mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa())); + std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa(), use_thp)); mmaps_used.emplace_back(mapping->size, 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); @@ -8077,7 +8126,7 @@ static bool llm_load_tensors( ml.done_getting_tensors(); - ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr); + ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr, ml.use_thp); model.mappings.reserve(ml.mappings.size()); // create the backend buffers @@ -8410,7 +8459,7 @@ static bool llm_load_tensors( static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { llama_model_loader ml(fname, params.use_mmap, params.check_tensors, - params.repack_tensors, params.kv_overrides, params.tensor_buft_overrides); + params.repack_tensors, params.use_thp, params.kv_overrides, params.tensor_buft_overrides); model.hparams.vocab_only = params.vocab_only; @@ -17494,7 +17543,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, /* use_thp */ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model; @@ -18318,6 +18367,7 @@ struct llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.repack_tensors =*/ false, + /*.use_thp =*/ false, }; #ifdef GGML_USE_METAL From 5a4855e61c05b0c54ecad3f4155074d8f344b6f6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 23 Mar 2025 07:28:21 +0100 Subject: [PATCH 056/132] Attempt to improve FlashMLA on the CPU (#277) * Fix it for nth > rk2 * Handle rk2%nth_k != 0 * Cleanup --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 15 ++++----- ggml/src/iqk/iqk_flash_attn.cpp | 60 ++++++++++++++++++++++----------- 2 files changed, 48 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a2bdc156..036bd8a8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21771,15 +21771,14 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa if (gcd_k > 1) { int nth_k = n_tasks/gcd_k; int rk2 = q->ne[2]/k->ne[2]; - if (rk2%nth_k == 0) { - size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks; - if (ggml_is_quantized(k->type)) { - enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; - size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); - size += q->ne[2]*row_size; - } - cur = MAX(cur, size); + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; } + cur = MAX(cur, size); } } #endif diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index fecd818b..2c6da2b9 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -64,40 +64,63 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, int gcd_k = simple_gcd(nstep_k, nth); if (gcd_k >= 1) { int nth_k = nth/gcd_k; - if (rk2%nth_k == 0) { - int ith_k = ith%gcd_k; - int ith_q = ith/gcd_k; + int ith_k = ith%gcd_k; + int ith_q = ith/gcd_k; + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + if (nq_per_thread > 1) { + int ith_mid = nth_k; + int nq_this_thread = nq_per_thread; + if (nq_per_thread*nth_k > rk2) { + ith_mid = rk2 - nth_k*(nq_per_thread - 1); + if (ith_q >= ith_mid) --nq_this_thread; + } + int j_mid = ith_mid*nq_per_thread; + auto work = (char *)work_buffer; + auto size_thread = (Dv + 16)*nq_per_thread*sizeof(float); + auto result_buffer = work; + auto kth = (const char *)k + ith_k*(nek1/gcd_k)*stride_k; auto vth = (const char *)v + ith_k*(nek1/gcd_k)*stride_v; - auto qth = (const char *)q + ith_q*(rk2/nth_k)*nbq2; + auto q_offset = ith_q < ith_mid ? ith_q*nq_per_thread*nbq2 : (ith_mid*nq_per_thread + (ith_q - ith_mid)*nq_this_thread)*nbq2; + auto qth = (const char *)q + q_offset; auto mth = (const char *)mask + ith_k*(nek1/gcd_k)*sizeof(uint16_t); // we don't have ggml_half available here - auto work = (char *)work_buffer; - // Each thread will produce a result of size Dv*(rk2/nth_k)*sizeof(float) - // In addition, we need M, S for the rk2/nth_k rows the thread is processing - // => (Dv + 2)*rk2/nth_k*sizeof(float). We use (Dv + 16) instead to make sure threads are not + // Each thread will produce a result of size Dv*nq_this_thread*sizeof(float) + // In addition, we need M, S for the nq_this_thread rows the thread is processing + // => (Dv + 2)*nq_per_thread*sizeof(float). We use (Dv + 16) instead to make sure threads are not // writing onto the same cache line. - auto size_thread = (Dv + 16)*rk2/nth_k*sizeof(float); - auto result_buffer = work; auto work_this_thread = (float *)(result_buffer + ith*size_thread); if (!iqk_flash_attn_impl(int_type_k, int_type_v, - Dk, Dv, rk2/nth_k, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, + Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv, (const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, scale, softcap, - work_this_thread, work_this_thread + (Dv+0)*rk2/nth_k, work_this_thread + (Dv+1)*rk2/nth_k)) return false; + work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false; barrier(barrier_data); + // There are nek1/gcd_k contributions for each j that we need to sum up + // Thread i computed k/v (i%gcd_k)*(nek1/gcd_k) for j (i/gcd_k)*(rk2/nth_k)...((i/gcd_k)+1)*(rk2/nth_k) and results at offset i*size_thread + // TODO: simdify this + // TODO: if nth > rk2, have threads process portions of the rows instead of entire rows as it is now for (int j = ith; j < rk2; j += nth) { auto Racc = qkv + j*nb1/sizeof(float); float M = -INFINITY, S = 0; - int jth_q = j/(rk2/nth_k); - int jj = j%(rk2/nth_k); - for (int j1 = 0; j1 < rk2/nth_k; ++j1) { - auto R = (const float *)(result_buffer + (jth_q*(rk2/nth_k) + j1)*size_thread); - auto Mj = R + Dv*rk2/nth_k; - auto Sj = Mj + rk2/nth_k; + int jth_first, jj, nq_this_j; + if (j < j_mid) { + jth_first = j/nq_per_thread; + jj = j%nq_per_thread; + nq_this_j = nq_per_thread; + } else { + jth_first = ith_mid + (j - j_mid)/(nq_per_thread-1); + jj = (j - j_mid)%(nq_per_thread-1); + nq_this_j = nq_per_thread - 1; + } + jth_first *= gcd_k; + for (int jth = jth_first; jth < jth_first + gcd_k; ++jth) { + auto R = (const float *)(result_buffer + jth*size_thread); + auto Mj = R + Dv*nq_this_j; + auto Sj = Mj + nq_this_j; R += jj*Dv; if (Mj[jj] == -INFINITY) continue; if (Mj[jj] > M) { @@ -120,7 +143,6 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, for (int i = 0; i < Dv; ++i) Racc[i] *= norm; } return true; - } } } From f9307d79071c2a1e8efe10ecb1e1304bf77c021a Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 23 Mar 2025 17:10:52 +0100 Subject: [PATCH 057/132] Improve DeepSeek batched processing speed (#282) * Improve DeepSeek batched processing speed * Revert the commented out section in iqk_mul_mat.cpp It does have some benefit at long contexts. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 16 ++++++++++++++-- src/llama.cpp | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8b6d6b1c..4d29e2f0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17265,13 +17265,25 @@ template inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { - if (nq1 % 8 == 0) { + if (nq1 >= 8) { FlashAttn<576, 512, 8, step_k> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - } else { + } + else if (nq1 >= 4) { + FlashAttn<576, 512, 4, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + } + else { FlashAttn<576, 512, 1, step_k> fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } + //if (nq1 % 8 == 0) { + // FlashAttn<576, 512, 8, step_k> fa(scale, softcap); + // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + //} else { + // FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + //} } template diff --git a/src/llama.cpp b/src/llama.cpp index 395b2879..24737265 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -13896,7 +13896,7 @@ struct llm_build_context { // whether to use n_tokens as the matrix dimension during multiplication or n_head // n_tokens is higher during prompt processing, this allows to optimize for this case - bool pp_opt = n_tokens > n_head; + bool pp_opt = n_tokens >= 128; // Is it a fixed constant or is it somehow relared to n_head? original: n_tokens > n_head; for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; From 98a264a2ea21761322847ac562f58d986ef6c512 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 25 Mar 2025 07:47:10 +0100 Subject: [PATCH 058/132] CUDA: better MoE implementation (#283) * Make fused MoE reproducible As a bonus, peak performance at pp2048 with u_batch = 2048 is ~8% better. * Slightly better * Also do it for non-fused mul_mat_id --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 171 +++++++++++++++++++++--------------------- 1 file changed, 86 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index e488c9bc..5f4731c8 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2164,35 +2164,19 @@ struct mmid_row_mapping { int32_t i2; }; -static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, - int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, - const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, - int64_t ne11, int64_t ne10, - size_t nb11, size_t nb12) { - int32_t iid1 = blockIdx.x; - int32_t id = blockIdx.y; +static __global__ void k_copy_src_to_contiguous(const char * __restrict__ src_original, char * __restrict__ src_contiguous, + const mmid_row_mapping * __restrict__ row_mapping, + int64_t ne10, int64_t ne11, size_t nb11, size_t nb12) { + int32_t i = blockIdx.x; - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const int32_t i11 = row_mapping[i].i1 % ne11; + const int32_t i12 = row_mapping[i].i2; - if (row_id_i != i02) { - return; - } + float * src_row_contiguous = (float *)(src_contiguous + i*nb11); + const float * src_row_original = (const float *)(src_original + i11*nb11 + i12*nb12); - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; - - __shared__ int src1_row; - if (threadIdx.x == 0) { - src1_row = atomicAdd(cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - __syncthreads(); - - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); - float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); - - for (int i = threadIdx.x; i < ne10; i += blockDim.x) { - src1_row_contiguous[i] = src1_row_original[i]; + for (int j = threadIdx.x; j < ne10; j += blockDim.x) { + src_row_contiguous[j] = src_row_original[j]; } } @@ -2213,6 +2197,51 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin } } +static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids, + const ggml_tensor * ids, std::vector& moe_counts, std::vector& cum_moe_counts, + ggml_cuda_pool_alloc& dev_row_mapping) { + + GGML_ASSERT(moe_counts.empty() && cum_moe_counts.empty()); + + auto stream = ctx.stream(); + + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + //CUDA_CHECK(cudaStreamSynchronize(stream)); + + std::vector rmapping(ids->ne[1]*n_ids); + moe_counts.resize(n_as, 0); + cum_moe_counts.resize(n_as + 1); + + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i]; + } + } + cum_moe_counts[0] = 0; + for (int i = 0; i < (int)n_as; ++i) { + cum_moe_counts[i+1] = cum_moe_counts[i] + moe_counts[i]; + } + + dev_row_mapping.alloc(cum_moe_counts[n_as]); + + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + if (row_id_i >= 0 && row_id_i < n_as) { + rmapping[cum_moe_counts[row_id_i]++] = {(int)id, (int)iid1}; + } + } + } + + for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i]; + + CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream)); + +} + static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -2273,10 +2302,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); + //std::vector ids_host(ggml_nbytes(ids)); + //const char * ids_dev = (const char *) ids->data; + //CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + //CUDA_CHECK(cudaStreamSynchronize(stream)); ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; @@ -2303,6 +2332,9 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.nb[3] = nb1; if (ne12 == 1) { + std::vector ids_host(ggml_nbytes(ids)); + const char * ids_dev = (const char *) ids->data; + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); @@ -2324,6 +2356,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } } else { + + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; + prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); ggml_cuda_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); @@ -2331,37 +2368,20 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.data = dst_contiguous.get(); for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - if (row_id_i != i02) { - continue; - } - - num_src1_rows++; - } - } + int64_t num_src1_rows = moe_counts[i02]; if (num_src1_rows == 0) { continue; } - ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + size_t mapping_offset = cum_moe_counts[i02]; { dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(ids->ne[1], n_ids); - k_copy_src1_to_contiguous<<>>( - src1_original, src1_contiguous.get(), - dev_cur_src1_row.get(), dev_row_mapping.get(), - ids_dev, i02, ids->nb[1], ids->nb[0], - ne11, ne10, - nb11, nb12); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); CUDA_CHECK(cudaGetLastError()); } @@ -2387,7 +2407,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( dst_original, dst_contiguous.get(), - dev_row_mapping.get(), + dev_row_mapping.get() + mapping_offset, ne0, nb1, nb2); CUDA_CHECK(cudaGetLastError()); @@ -2642,41 +2662,22 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor bool first = false; //true; + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool()); + std::vector moe_counts, cum_moe_counts; + + prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping); + for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; + int64_t num_src1_rows = moe_counts[i02]; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - if (row_id_i < 0 || row_id_i >= n_as) continue; - //GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - - if (row_id_i != i02) { - continue; - } - - num_src1_rows++; - } - } - - if (num_src1_rows == 0) { - continue; - } - - ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); + if (num_src1_rows == 0) continue; + size_t mapping_offset = cum_moe_counts[i02]; { dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(ids->ne[1], n_ids); - k_copy_src1_to_contiguous<<>>( - src1_original, src1_contiguous.get(), - dev_cur_src1_row.get(), dev_row_mapping.get(), - ids_dev, i02, ids->nb[1], ids->nb[0], - ne11, ne10, - nb11, nb12); + dim3 grid_dims(num_src1_rows); + k_copy_src_to_contiguous<<>>( + src1_original, src1_contiguous.get(), dev_row_mapping.get() + mapping_offset, ne10, ne11, nb11, nb12); CUDA_CHECK(cudaGetLastError()); } @@ -2733,7 +2734,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( (char *)next->data, final_dst_contiguous.get(), - dev_row_mapping.get(), + dev_row_mapping.get() + mapping_offset, next->ne[0], next->nb[1], next->nb[2]); CUDA_CHECK(cudaGetLastError()); @@ -2745,7 +2746,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( dst_original, dst_gate_contiguous.get(), - dev_row_mapping.get(), + dev_row_mapping.get() + mapping_offset, ne0, nb1, nb2); CUDA_CHECK(cudaGetLastError()); From 279b7d33950c7f2a1de29231179b600294784ade Mon Sep 17 00:00:00 2001 From: saood06 Date: Tue, 25 Mar 2025 10:14:44 -0500 Subject: [PATCH 059/132] Update sweep bench (depracating .jsonl support) (#289) * Update sweep bench (depracating .jsonl support) * Fix README.md --- examples/sweep-bench/README.md | 1 + examples/sweep-bench/sweep-bench-plot.py | 66 +++++++++++++++--------- examples/sweep-bench/sweep-bench.cpp | 28 +++++----- 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/examples/sweep-bench/README.md b/examples/sweep-bench/README.md index 608fd104..d92740de 100644 --- a/examples/sweep-bench/README.md +++ b/examples/sweep-bench/README.md @@ -7,6 +7,7 @@ in each ubatch-sized window. Only a single token sequence is used. The benchmark steps are: for each ubatch-sized window in context: + 1. generate ubatch/4 tokens (not the whole window to save some time) 2. measure generation performance 3. remove generated tokens from KV cache diff --git a/examples/sweep-bench/sweep-bench-plot.py b/examples/sweep-bench/sweep-bench-plot.py index 0d3c81db..481a604c 100755 --- a/examples/sweep-bench/sweep-bench-plot.py +++ b/examples/sweep-bench/sweep-bench-plot.py @@ -9,27 +9,54 @@ args = parser.parse_args() df = None -for jsonl_file in args.file: - # Read JSONL file into DataFrame - df_part = pd.read_json(jsonl_file, lines=True) - df_part['label'] = jsonl_file - if df is None: - df = df_part - else: - df = pd.concat([df, df_part]) +#for jsonl_file in args.file: +# # Read JSONL file into DataFrame +# df_part = pd.read_json(jsonl_file, lines=True) +# df_part['label'] = jsonl_file +# if df is None: +# df = df_part +# else: +# df = pd.concat([df, df_part]) +# -# Group by model and n_kv, calculate mean and std for both speed metrics + + +for md_file in args.file: + # Read markdown table file into DataFrame + df_part = pd.read_csv(md_file, sep=r'\s*\|\s*', engine='python', + header=0, skiprows=[1]) + + # Clean up columns (remove empty columns from markdown formatting) + df_part = df_part.iloc[:, 1:-1] + df_part.columns = [col.strip() for col in df_part.columns] + + # Rename columns to match expected names + df_part = df_part.rename(columns={ + 'N_KV': 'n_kv', + 'S_PP t/s': 'speed_pp', + 'S_TG t/s': 'speed_tg' + }) + + # Convert to numeric types + df_part['n_kv'] = pd.to_numeric(df_part['n_kv']) + df_part['speed_pp'] = pd.to_numeric(df_part['speed_pp']) + df_part['speed_tg'] = pd.to_numeric(df_part['speed_tg']) + + # Add label and append to main DataFrame + df_part['label'] = md_file + df = pd.concat([df, df_part]) if df is not None else df_part + +# Group by label and n_kv, calculate mean and std for both speed metrics df_grouped = df.groupby(['label', 'n_kv']).agg({ 'speed_pp': ['mean', 'std'], 'speed_tg': ['mean', 'std'] }).reset_index() # Flatten multi-index columns -df_grouped.columns = ['label', 'n_kv', 'speed_pp_mean', 'speed_pp_std', +df_grouped.columns = ['label', 'n_kv', 'speed_pp_mean', 'speed_pp_std', 'speed_tg_mean', 'speed_tg_std'] # Replace NaN with 0 (std for a single sample is NaN) - df_grouped['speed_pp_std'] = df_grouped['speed_pp_std'].fillna(0) df_grouped['speed_tg_std'] = df_grouped['speed_tg_std'].fillna(0) @@ -45,25 +72,20 @@ colors = plt.cm.rainbow(np.linspace(0, 1, len(labels))) # Create prompt processing plot plt.figure(figsize=(10, 6)) ax1 = plt.gca() - plt.grid() - ax1.set_xticks(x_ticks) # Plot each label's data for label, color in zip(labels, colors): label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv') - - # Plot prompt processing - pp = ax1.errorbar(label_data['n_kv'], label_data['speed_pp_mean'], - yerr=label_data['speed_pp_std'], color=color, + pp = ax1.errorbar(label_data['n_kv'], label_data['speed_pp_mean'], + yerr=label_data['speed_pp_std'], color=color, marker='o', linestyle='-', label=label) - + # Add labels and title ax1.set_xlabel('Context Length (tokens)') ax1.set_ylabel('Prompt Processing Rate (t/s)') plt.title('Prompt Processing Performance Comparison') - ax1.legend(loc='upper right') # Adjust layout and save @@ -74,24 +96,20 @@ plt.close() # Create token generation plot plt.figure(figsize=(10, 6)) ax1 = plt.gca() - plt.grid() ax1.set_xticks(x_ticks) # Plot each model's data for label, color in zip(labels, colors): label_data = df_grouped[df_grouped['label'] == label].sort_values('n_kv') - - # Plot token generation tg = ax1.errorbar(label_data['n_kv'], label_data['speed_tg_mean'], - yerr=label_data['speed_tg_std'], color=color, + yerr=label_data['speed_tg_std'], color=color, marker='s', linestyle='-', label=label) # Add labels and title ax1.set_xlabel('Context Length (n_kv)') ax1.set_ylabel('Token Generation Rate (t/s)') plt.title('Token Generation Performance Comparison') - ax1.legend(loc='upper right') # Adjust layout and save diff --git a/examples/sweep-bench/sweep-bench.cpp b/examples/sweep-bench/sweep-bench.cpp index 4e594de5..27510687 100644 --- a/examples/sweep-bench/sweep-bench.cpp +++ b/examples/sweep-bench/sweep-bench.cpp @@ -18,9 +18,9 @@ #include static void print_usage(int, char ** argv) { - LOG("\nexample usage:\n"); - LOG("\n %s -m model.gguf -c 8192 -b 2048 -ub 512\n", argv[0]); - LOG("\n"); + LOG_TEE("\nexample usage:\n"); + LOG_TEE("\n %s -m model.gguf -c 8192 -b 2048 -ub 512\n", argv[0]); + LOG_TEE("\n"); } int main(int argc, char ** argv) { @@ -83,7 +83,7 @@ int main(int argc, char ** argv) { const int ret = llama_decode(ctx, batch_view); if (ret != 0) { - LOG("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; } @@ -97,11 +97,11 @@ int main(int argc, char ** argv) { const unsigned int tg = params.n_ubatch / 4; if (!params.sweep_bench_output_jsonl) { - LOG("\n"); - LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); - LOG("\n"); - LOG("|%6s | %6s | %6s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s"); - LOG("|%6s-|-%6s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "------", "--------", "--------", "--------", "--------"); + LOG_TEE("\n"); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("\n"); + LOG_TEE("|%6s | %6s | %6s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s"); + LOG_TEE("|%6s-|-%6s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "------", "--------", "--------", "--------", "--------"); } llama_batch batch = llama_batch_init(n_kv_max, 0, 1); @@ -111,7 +111,7 @@ int main(int argc, char ** argv) { llama_batch_add(batch, bos, 0, { 0 }, false); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG("%s: llama_decode() failed\n", __func__); + LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } } @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG("%s: llama_decode() failed\n", __func__); + LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } } @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG("%s: llama_decode() failed\n", __func__); + LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } @@ -167,14 +167,14 @@ int main(int argc, char ** argv) { const float speed_tg = tg / t_tg; if(params.sweep_bench_output_jsonl) { - LOG( + LOG_TEE( "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, " "\"pp\": %d, \"tg\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f }\n", n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch, pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg ); } else { - LOG("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg); + LOG_TEE("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg); } } From a22250df93fd833a6cb7f310b159ad1b54e4d582 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 25 Mar 2025 16:31:17 +0100 Subject: [PATCH 060/132] llama-bench: enable having different number of threads for tg and pp (#284) * llama-bench: enable having different number of threads for tg and pp * Add -tgb to usage --------- Co-authored-by: Iwan Kawrakow --- examples/llama-bench/llama-bench.cpp | 44 +++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 24d2e185..ae79ad02 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -41,6 +41,12 @@ static uint64_t get_time_ns() { return std::chrono::nanoseconds(clock::now().time_since_epoch()).count(); } +template +std::ostream& operator<<(std::ostream& str, const std::pair& item) { + str << '{' << item.first << ", " << item.second << '}'; + return str; +} + template static std::string join(const std::vector & values, const std::string & delim) { std::ostringstream str; @@ -228,7 +234,7 @@ struct cmd_params { std::vector n_ubatch; std::vector type_k; std::vector type_v; - std::vector n_threads; + std::vector> n_threads; std::vector n_gpu_layers; std::vector rpc_servers; std::vector split_mode; @@ -263,7 +269,7 @@ static const cmd_params cmd_params_defaults = { /* n_ubatch */ {512}, /* type_k */ {GGML_TYPE_F16}, /* type_v */ {GGML_TYPE_F16}, - /* n_threads */ {cpu_get_num_math()}, + /* n_threads */ {{cpu_get_num_math(), cpu_get_num_math()}}, /* n_gpu_layers */ {99}, /* rpc_servers */ {""}, /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, @@ -303,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); + printf(" -tgb, --threads-gen-batch (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); @@ -538,7 +545,23 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } auto p = string_split(argv[i], split_delim); - params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); + params.n_threads.reserve(params.n_threads.size() + p.size()); + for (auto t : p) params.n_threads.push_back({t, t}); + //params.n_threads.insert(params.n_threads.end(), p.begin(), p.end()); + } else if (arg == "-tgb" || arg == "--threads-gen-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto ps = string_split(argv[i], ';'); + for (auto& s : ps) { + auto p = string_split(s.c_str(), ','); + if (p.size() != 2) { + invalid_param = true; + break; + } + params.n_threads.push_back({p[0], p[1]}); + } } else if (arg == "-ngl" || arg == "--n-gpu-layers") { if (++i >= argc) { invalid_param = true; @@ -775,7 +798,7 @@ struct cmd_params_instance { int n_ubatch; ggml_type type_k; ggml_type type_v; - int n_threads; + std::pair n_threads; int n_gpu_layers; std::string rpc_servers; llama_split_mode split_mode; @@ -1024,7 +1047,7 @@ struct test { uint64_t model_n_params; int n_batch; int n_ubatch; - int n_threads; + std::pair n_threads; bool has_rpc; ggml_type type_k; ggml_type type_v; @@ -1218,6 +1241,7 @@ struct test { str << ser.first << ',' << ser.second; return str.str(); }; + bool is_gen = n_gen > 0; std::vector values = { build_commit, std::to_string(build_number), std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan), @@ -1225,7 +1249,7 @@ struct test { cpu_info, gpu_info, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), std::to_string(n_batch), std::to_string(n_ubatch), - std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), + std::to_string(is_gen ? n_threads.first : n_threads.second), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), @@ -1787,10 +1811,10 @@ int main(int argc, char ** argv) { if (params.warmup) { if (t.n_prompt > 0) { //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, 1, 0, t.n_batch, t.n_threads); + test_prompt(ctx, 1, 0, t.n_batch, t.n_threads.second); } if (t.n_gen > 0) { - test_gen(ctx, 1, 0, t.n_threads); + test_gen(ctx, 1, 0, t.n_threads.first); } } @@ -1800,11 +1824,11 @@ int main(int argc, char ** argv) { uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads.second); } if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns(); if (t.n_gen > 0) { - test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); + test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads.first); } uint64_t t_ns = get_time_ns() - t_start; From d0b52076da0261f291b01f1ffa44884c8b2cdb1c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 27 Mar 2025 05:49:16 +0100 Subject: [PATCH 061/132] Use bf16 instead of fp16 block scales for q8_1 (#292) * WIP - not working * q8_0 without bells and wistles works * It works for q8_0 * Use bf16 instead of f16,int16 * q4_0_r8 * q5_0_r4 * q6_0_r4 * Also q4_1 and q5_1 * q8_0_r8 on avx2 --------- Co-authored-by: Iwan Kawrakow --- ggml/include/ggml.h | 5 +- ggml/src/ggml-common.h | 14 + ggml/src/ggml.c | 39 ++- ggml/src/iqk/iqk_mul_mat.cpp | 477 ++++++++++++++++++---------------- ggml/src/iqk/iqk_quantize.cpp | 67 ++++- ggml/src/iqk/iqk_quantize.h | 1 + 6 files changed, 348 insertions(+), 255 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 91219d4a..7cc9100d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -396,8 +396,9 @@ extern "C" { // GGML_TYPE_I2_S = 36, // - GGML_TYPE_Q8_0_X4 = 98, - GGML_TYPE_Q8_1_X4 = 99, + GGML_TYPE_Q8_0_X4 = 97, + GGML_TYPE_Q8_1_X4 = 98, + GGML_TYPE_Q8_2_X4 = 99, GGML_TYPE_Q6_0 = 133, GGML_TYPE_IQ1_BN = 134, GGML_TYPE_IQ2_BN = 135, diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 4308f0b9..59702e32 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -266,6 +266,20 @@ typedef struct { } block_q8_0x8; static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); +#define QK8_2 32 +typedef struct { + uint16_t d; + uint16_t s; + int8_t qs[QK8_2]; // quants +} block_q8_2; +static_assert(sizeof(block_q8_2) == sizeof(ggml_half) + sizeof(int16_t) + QK8_2, "wrong q8_2 block size/padding"); + +typedef struct { + uint16_t d[8]; + int8_t qs[4*QK8_2]; +} block_q8_2_x4; +static_assert(sizeof(block_q8_2_x4) == 4*sizeof(block_q8_2), "wrong q8_2_x4 block size/padding"); + // // Super-block quantization structures // diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 036bd8a8..25694fc7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -717,7 +717,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q4_0_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -741,7 +741,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, #if GGML_USE_IQK_MULMAT - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_1, #endif @@ -789,7 +789,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q5_0_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -809,7 +809,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, #if GGML_USE_IQK_MULMAT - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_1, #endif @@ -827,7 +827,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_q6_0_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -852,7 +852,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { // Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2 // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range // (and it gets satured if it does), leading to wrong results. - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -897,6 +897,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q8_2_X4] = { + .type_name = "q8_2_x4", + .blck_size = QK8_2, + .type_size = sizeof(block_q8_2), + .is_quantized = true, + .from_float = quantize_row_q8_2_x4, + .from_float_ref = quantize_row_q8_2_x4, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1272,7 +1282,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = ggml_vec_dot_iq4_nl_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1628,7 +1638,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_iq4_nl_r4_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1662,7 +1672,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q4_0_r8_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1683,7 +1693,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q8_0_r8_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1704,7 +1714,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q5_0_r4_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -1725,7 +1735,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = vec_dot_q6_0_r4_q8_0, #if GGML_USE_IQK_MULMAT #if defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1_X4, + .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, #endif @@ -11647,6 +11657,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -11815,6 +11826,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -15690,6 +15702,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -15997,6 +16010,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: @@ -16627,6 +16641,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q8_1: case GGML_TYPE_Q8_0_X4: case GGML_TYPE_Q8_1_X4: + case GGML_TYPE_Q8_2_X4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 4d29e2f0..cf512ba5 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3436,9 +3436,9 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da #ifdef HAVE_FANCY_SIMD template -static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); auto values = load_iq4nl_values_512(); int nb = n / QK4_NL; @@ -3475,7 +3475,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); @@ -3492,9 +3493,10 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -3509,9 +3511,9 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data } #else template -static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m1 = _mm256_set1_epi16(1); auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); @@ -3548,7 +3550,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm_storeu_ps(d8+4*iy, _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm_storeu_ps(d8+4*iy, _mm_castsi128_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4[4*ib4+k]); @@ -3564,7 +3567,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } @@ -3623,7 +3627,7 @@ inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) { } template -static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); @@ -3637,7 +3641,7 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D auto acc1 = _mm256_setzero_ps(); auto acc2 = _mm256_setzero_ps(); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - helper.vec = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)); + helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16)); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); @@ -3652,9 +3656,10 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); auto sumi = accum_q4_0_quants(v, qy[ib].qs); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); - acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc2); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2); } acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1); info.store(ix, 0, acc1); @@ -3672,7 +3677,7 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); } for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, scales); auto m4 = _mm256_extractf128_ps(scales, 1); auto m8 = _mm256_set_m128(m4, m4); @@ -3700,9 +3705,10 @@ static void mul_mat_q4_0_r8_q8_1_avx2(int n, const void * vx, size_t bx, const D for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = accum_q4_0_quants(v, qy[ib].qs); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -3977,9 +3983,9 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q4_0_r8_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q4_0_r8_q8_2_avx2<1>(n, vx, bx, info, nrc_x); return; } GGML_ASSERT(nrc_x%16 == 0); @@ -4024,7 +4030,8 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); @@ -4041,9 +4048,10 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(qy[ib].qs); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4055,15 +4063,15 @@ static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q4_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q4_0_r8_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q4_0_r8_q8_2_avx2(n, vx, bx, info, nrc_x); } #endif template -static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m5 = _mm256_set1_epi8(0x10); #ifndef HAVE_FANCY_SIMD @@ -4110,7 +4118,7 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales)); } for (int k = 0; k < 4; ++k) { @@ -4128,9 +4136,10 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_BF16_TO_FP32(s)), acc[iy]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4143,12 +4152,12 @@ static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q5_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q5_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x); } else { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); auto m5 = _mm512_set1_epi8(0x10); int nb = n / QK5_0; @@ -4190,7 +4199,7 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16))); } for (int k = 0; k < 4; ++k) { auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]); @@ -4207,9 +4216,10 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4225,15 +4235,15 @@ static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q5_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q5_0_r4_q8_2_avx2(n, vx, bx, info, nrc_x); } #endif template -static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m6 = _mm256_set1_epi8(0x30); auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f)); @@ -4278,7 +4288,7 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale)); } for (int k = 0; k < 4; ++k) { @@ -4296,9 +4306,10 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_FP16_TO_FP32(qy[ib].s)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]); } } @@ -4312,12 +4323,12 @@ static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const D #ifdef HAVE_FANCY_SIMD template -static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { if constexpr (nrc_y == 1) { - mul_mat_q6_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + mul_mat_q6_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x); } else { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m4 = _mm512_set1_epi8(0xf); auto m6 = _mm512_set1_epi8(0x30); int nb = n / QK6_0; @@ -4357,7 +4368,7 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm256_storeu_ps(d8 + 8*iy, scales); } for (int k = 0; k < 4; ++k) { @@ -4375,9 +4386,10 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4393,8 +4405,8 @@ static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - mul_mat_q6_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q6_0_r4_q8_2_avx2(n, vx, bx, info, nrc_x); } #endif @@ -4437,20 +4449,12 @@ inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i for (int i = 0; i < 8; ++i) { qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127)); } - //qx[0] = _mm256_loadu_si256((const __m256i *)x+0); - //qx[1] = _mm256_loadu_si256((const __m256i *)x+1); - //qx[2] = _mm256_loadu_si256((const __m256i *)x+2); - //qx[3] = _mm256_loadu_si256((const __m256i *)x+3); - //qx[4] = _mm256_loadu_si256((const __m256i *)x+4); - //qx[5] = _mm256_loadu_si256((const __m256i *)x+5); - //qx[6] = _mm256_loadu_si256((const __m256i *)x+6); - //qx[7] = _mm256_loadu_si256((const __m256i *)x+7); return qx_r8_q8_dot_product(qx, y); } template -static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%16 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / QK8_0; if constexpr (nrc_y == 1) { __m256 acc[2] = {}; @@ -4459,7 +4463,8 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int ix = 0; ix < nrc_x; ix += 8) { const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { - _mm256_storeu_ps(d8, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16); + _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux)); for (int k = 0; k < 4; ++k) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx); @@ -4473,9 +4478,10 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int ib = 4*(nb/4); ib < nb; ++ib) { auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); - acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[1]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]); } } info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0])); @@ -4490,7 +4496,8 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - _mm256_storeu_ps(d8+8*iy, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d))); + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); } for (int k = 0; k < 4; ++k) { auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); @@ -4521,9 +4528,10 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn for (int iy = 0; iy < nrc_y; ++iy) { auto qy = (const block_q8_1 *)q8.y[iy]; auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs); - auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].d)); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_FP16_TO_FP32(qy[ib].s)), acc[2*iy+1]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); } } for (int iy = 0; iy < nrc_y; ++iy) { @@ -4536,9 +4544,9 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn } #else template -static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); + Q8 q8(info); auto m1 = _mm256_set1_epi16(1); int nb = n / QK8_0; __m256 acc[nrc_y] = {}; @@ -4561,7 +4569,7 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)); + auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16)); _mm_storeu_ps(d8 + 4*iy, scales); } for (int k = 0; k < 4; ++k) { @@ -4593,9 +4601,9 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn sx[j] = _mm256_sign_epi8(qx[j], qx[j]); } for (int iy = 0; iy < nrc_y; ++iy) { - auto qy = (const block_q8_1 *)q8.y[iy]; + auto qy = (const block_q8_2 *)q8.y[iy]; auto sumi = dot(qy[ib].qs); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d}))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } for (int j = 0; j < 4; ++j) { @@ -4603,9 +4611,9 @@ static void mul_mat_q8_0_r8_q8_1(int n, const void * vx, size_t bx, const DataIn sx[j] = _mm256_sign_epi8(qx[j], qx[j]); } for (int iy = 0; iy < nrc_y; ++iy) { - auto qy = (const block_q8_1 *)q8.y[iy]; + auto qy = (const block_q8_2 *)q8.y[iy]; auto sumi = dot(qy[ib].qs+16); - auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(qy[ib].d))); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d}))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); } } @@ -8199,6 +8207,29 @@ struct ScaleHelperQ_0_1 { const __m128 min = _mm_set1_ps(float(-min_value)); }; +//template +//struct ScaleHelperQ_0_2 { +// ggml_bf16_t scales8[4]; +// template +// inline __m256 prepare4(const Q * y) { +// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; +// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16)); +// return _mm256_set_m128(_mm_mul_ps(s4, min), s4); +// } +// template +// inline __m256 prepare4(__m256 other_scales, const Q * y) { +// return _mm_mul256_ps(other_scales, prepare4(y)); +// } +// template inline std::pair prepare1(const Q * y) const { +// float d = GGML_BF16_TO_FP32(y->d); +// return std::make_pair(d, -d*float(min_value)); +// } +// std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { +// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); +// } +// const __m128 min = _mm_set1_ps(float(-min_value)); +//}; + struct ScaleHelperQ8_1 { template inline __m256 prepare4(const Q * y) { @@ -8220,6 +8251,30 @@ struct ScaleHelperQ8_1 { } }; +struct ScaleHelperQ8_2 { + template + inline __m256 prepare4(const Q * y) { + const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y; + auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d)); + return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16)); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; + return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_2 * y) const { + ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; + return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + } +}; + struct ScaleHelperQ_1 { uint32_t scales8[4]; const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100); @@ -8320,7 +8375,8 @@ using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; using Sum4Type0 = Sum4; using Sum4Type1 = Sum4; using Sum4TypeQ80 = Sum4; -using Sum4TypeQ81 = Sum4; +//using Sum4TypeQ81 = Sum4; +using Sum4TypeQ82 = Sum4; template void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { @@ -8366,6 +8422,22 @@ void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info } } +template +void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } +} + struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { @@ -8494,73 +8566,6 @@ struct Q_Unpacker { } }; -struct Q8_0_x4_Unpacker_256 { - using Sum4T = Sum4TypeQ80; - inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} - - const char * cx_0; - const block_q8_0_x4 * x; - size_t bx; - - __m256i qx[4]; - - inline const __m256i* quants() const { return qx; } - - inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } - - inline auto set_block_4(int i) { - auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); - for (int j = 0; j < 4; ++j) { - qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); - } - return scales; - } - inline auto set_block(int i) { - auto q8 = (const block_q8_0 *)(x + i); - qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); - return GGML_FP16_TO_FP32(q8->d); - } -}; - -#ifdef HAVE_FANCY_SIMD -struct Q8_0_x4_Unpacker_512 { - using Sum4T = Sum4TypeQ81; - inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} - - const char * cx_0; - const block_q8_0_x4 * x; - size_t bx; - const __m128 min = _mm_set1_ps(-128.f); - - __m256i qx[4]; - - inline const __m256i* quants() const { return qx; } - - inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } - - inline auto set_block_4(int i) { - auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); - for (int j = 0; j < 4; ++j) { - qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); - qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(-128)); - } - return _mm256_set_m128(_mm_mul_ps(scales, min), scales); - } - inline auto set_block(int i) { - auto q8 = (const block_q8_0 *)(x + i); - qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); - qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(-128)); - float d = GGML_FP16_TO_FP32(q8->d); - return std::make_pair(d, -128.f*d); - } -}; -using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512; -#else -using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256; -#endif - struct Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -8568,7 +8573,7 @@ struct Q8_0_Unpacker final : public Q_Unpacker, Q8_0_1_Dequantizer> { Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK8_0; } }; struct Q4_0_Unpacker final : public Q_Unpacker { @@ -8578,12 +8583,12 @@ struct Q4_0_Unpacker final : public Q_Unpacker, Q4_0_1_Dequantizer> { Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_0; } }; struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> { IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_NL; } }; struct Q5_0_Unpacker final : public Q_Unpacker { @@ -8593,22 +8598,22 @@ struct Q5_0_Unpacker final : public Q_Unpacker, Q5_1_Dequantizer> { Q5_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK5_0; } }; struct Q4_1_Unpacker final : public Q_Unpacker { Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4Type1; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_1; } }; struct Q5_1_Unpacker final : public Q_Unpacker> { Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4Type1; - inline static int block_size() { return QK4_1; } + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK5_1; } }; struct Q6_0_1_Unpacker final : public Q_Unpacker, Q6_0_1_Dequantizer> { Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ81; + using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK6_0; } }; @@ -9096,18 +9101,27 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_0_q8_0_T; m.funcs[7] = mul_mat_qX_0_q8_0_T; } - else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_1_q8_2_T; + m.funcs[1] = mul_mat_qX_1_q8_2_T; + m.funcs[2] = mul_mat_qX_1_q8_2_T; + m.funcs[3] = mul_mat_qX_1_q8_2_T; + m.funcs[4] = mul_mat_qX_1_q8_2_T; + m.funcs[5] = mul_mat_qX_1_q8_2_T; + m.funcs[6] = mul_mat_qX_1_q8_2_T; + m.funcs[7] = mul_mat_qX_1_q8_2_T; + } + else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_1_q8_1_T; - m.funcs[1] = mul_mat_qX_1_q8_1_T; - m.funcs[2] = mul_mat_qX_1_q8_1_T; - m.funcs[3] = mul_mat_qX_1_q8_1_T; - m.funcs[4] = mul_mat_qX_1_q8_1_T; - m.funcs[5] = mul_mat_qX_1_q8_1_T; - m.funcs[6] = mul_mat_qX_1_q8_1_T; - m.funcs[7] = mul_mat_qX_1_q8_1_T; + m.funcs[0] = mul_mat_qX_1_q8_2_T; + m.funcs[1] = mul_mat_qX_1_q8_2_T; + m.funcs[2] = mul_mat_qX_1_q8_2_T; + m.funcs[3] = mul_mat_qX_1_q8_2_T; + m.funcs[4] = mul_mat_qX_1_q8_2_T; + m.funcs[5] = mul_mat_qX_1_q8_2_T; + m.funcs[6] = mul_mat_qX_1_q8_2_T; + m.funcs[7] = mul_mat_qX_1_q8_2_T; } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -9383,33 +9397,33 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q6_0: assert (ne00 % QK6_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); #ifdef HAVE_FANCY_SIMD MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; #else MulMat::set_functions(mm); expected_typeB = GGML_TYPE_Q8_0_X4; @@ -9418,19 +9432,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ4_NL: assert (ne00 % QK4_NL == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_IQ4_NL_R4: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_iq4_nl_r4_q8_1<1>; - mm.funcs[1] = mul_mat_iq4_nl_r4_q8_1<2>; - mm.funcs[2] = mul_mat_iq4_nl_r4_q8_1<3>; - mm.funcs[3] = mul_mat_iq4_nl_r4_q8_1<4>; - mm.funcs[4] = mul_mat_iq4_nl_r4_q8_1<5>; - mm.funcs[5] = mul_mat_iq4_nl_r4_q8_1<6>; - mm.funcs[6] = mul_mat_iq4_nl_r4_q8_1<7>; - mm.funcs[7] = mul_mat_iq4_nl_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_iq4_nl_r4_q8_2<1>; + mm.funcs[1] = mul_mat_iq4_nl_r4_q8_2<2>; + mm.funcs[2] = mul_mat_iq4_nl_r4_q8_2<3>; + mm.funcs[3] = mul_mat_iq4_nl_r4_q8_2<4>; + mm.funcs[4] = mul_mat_iq4_nl_r4_q8_2<5>; + mm.funcs[5] = mul_mat_iq4_nl_r4_q8_2<6>; + mm.funcs[6] = mul_mat_iq4_nl_r4_q8_2<7>; + mm.funcs[7] = mul_mat_iq4_nl_r4_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_IQ4_XS_R8: assert (ne00 % QK_K == 0); @@ -9685,54 +9699,54 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_Q4_0_R8: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q4_0_r8_q8_1<1>; - mm.funcs[1] = mul_mat_q4_0_r8_q8_1<2>; - mm.funcs[2] = mul_mat_q4_0_r8_q8_1<3>; - mm.funcs[3] = mul_mat_q4_0_r8_q8_1<4>; - mm.funcs[4] = mul_mat_q4_0_r8_q8_1<5>; - mm.funcs[5] = mul_mat_q4_0_r8_q8_1<6>; - mm.funcs[6] = mul_mat_q4_0_r8_q8_1<7>; - mm.funcs[7] = mul_mat_q4_0_r8_q8_1<8>; + mm.funcs[0] = mul_mat_q4_0_r8_q8_2<1>; + mm.funcs[1] = mul_mat_q4_0_r8_q8_2<2>; + mm.funcs[2] = mul_mat_q4_0_r8_q8_2<3>; + mm.funcs[3] = mul_mat_q4_0_r8_q8_2<4>; + mm.funcs[4] = mul_mat_q4_0_r8_q8_2<5>; + mm.funcs[5] = mul_mat_q4_0_r8_q8_2<6>; + mm.funcs[6] = mul_mat_q4_0_r8_q8_2<7>; + mm.funcs[7] = mul_mat_q4_0_r8_q8_2<8>; #ifdef HAVE_FANCY_SIMD - mm.func16 = mul_mat_q4_0_r8_q8_1<16>; + mm.func16 = mul_mat_q4_0_r8_q8_2<16>; #endif - expected_typeB = GGML_TYPE_Q8_1_X4; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q5_0_R4: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q5_0_r4_q8_1<1>; - mm.funcs[1] = mul_mat_q5_0_r4_q8_1<2>; - mm.funcs[2] = mul_mat_q5_0_r4_q8_1<3>; - mm.funcs[3] = mul_mat_q5_0_r4_q8_1<4>; - mm.funcs[4] = mul_mat_q5_0_r4_q8_1<5>; - mm.funcs[5] = mul_mat_q5_0_r4_q8_1<6>; - mm.funcs[6] = mul_mat_q5_0_r4_q8_1<7>; - mm.funcs[7] = mul_mat_q5_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_q5_0_r4_q8_2<1>; + mm.funcs[1] = mul_mat_q5_0_r4_q8_2<2>; + mm.funcs[2] = mul_mat_q5_0_r4_q8_2<3>; + mm.funcs[3] = mul_mat_q5_0_r4_q8_2<4>; + mm.funcs[4] = mul_mat_q5_0_r4_q8_2<5>; + mm.funcs[5] = mul_mat_q5_0_r4_q8_2<6>; + mm.funcs[6] = mul_mat_q5_0_r4_q8_2<7>; + mm.funcs[7] = mul_mat_q5_0_r4_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q6_0_R4: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q6_0_r4_q8_1<1>; - mm.funcs[1] = mul_mat_q6_0_r4_q8_1<2>; - mm.funcs[2] = mul_mat_q6_0_r4_q8_1<3>; - mm.funcs[3] = mul_mat_q6_0_r4_q8_1<4>; - mm.funcs[4] = mul_mat_q6_0_r4_q8_1<5>; - mm.funcs[5] = mul_mat_q6_0_r4_q8_1<6>; - mm.funcs[6] = mul_mat_q6_0_r4_q8_1<7>; - mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_q6_0_r4_q8_2<1>; + mm.funcs[1] = mul_mat_q6_0_r4_q8_2<2>; + mm.funcs[2] = mul_mat_q6_0_r4_q8_2<3>; + mm.funcs[3] = mul_mat_q6_0_r4_q8_2<4>; + mm.funcs[4] = mul_mat_q6_0_r4_q8_2<5>; + mm.funcs[5] = mul_mat_q6_0_r4_q8_2<6>; + mm.funcs[6] = mul_mat_q6_0_r4_q8_2<7>; + mm.funcs[7] = mul_mat_q6_0_r4_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_Q8_0_R8: assert (ne00 % QK4_NL == 0); - mm.funcs[0] = mul_mat_q8_0_r8_q8_1<1>; - mm.funcs[1] = mul_mat_q8_0_r8_q8_1<2>; - mm.funcs[2] = mul_mat_q8_0_r8_q8_1<3>; - mm.funcs[3] = mul_mat_q8_0_r8_q8_1<4>; - mm.funcs[4] = mul_mat_q8_0_r8_q8_1<5>; - mm.funcs[5] = mul_mat_q8_0_r8_q8_1<6>; - mm.funcs[6] = mul_mat_q8_0_r8_q8_1<7>; - mm.funcs[7] = mul_mat_q8_0_r8_q8_1<8>; - expected_typeB = GGML_TYPE_Q8_1_X4; + mm.funcs[0] = mul_mat_q8_0_r8_q8_2<1>; + mm.funcs[1] = mul_mat_q8_0_r8_q8_2<2>; + mm.funcs[2] = mul_mat_q8_0_r8_q8_2<3>; + mm.funcs[3] = mul_mat_q8_0_r8_q8_2<4>; + mm.funcs[4] = mul_mat_q8_0_r8_q8_2<5>; + mm.funcs[5] = mul_mat_q8_0_r8_q8_2<6>; + mm.funcs[6] = mul_mat_q8_0_r8_q8_2<7>; + mm.funcs[7] = mul_mat_q8_0_r8_q8_2<8>; + expected_typeB = GGML_TYPE_Q8_2_X4; break; case GGML_TYPE_IQ1_S: mm.funcs[0] = mul_mat_iq1_s_q8_K<1>; @@ -15219,8 +15233,8 @@ template struct HelperQ80 final : public BaseHelper { using Base = BaseHelper; #ifdef HAVE_FANCY_SIMD - using block_q8 = block_q8_1; - constexpr static int block_size_q = QK8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; #else using block_q8 = block_q8_0; constexpr static int block_size_q = QK8_0; @@ -15268,6 +15282,15 @@ struct HelperQ80 final : public BaseHelper { } } + static inline void convert(int nq, int stride_q, const float * q, block_q8_2 * y) { + //GGML_ASSERT(nq <= step); Why did I have this assert? + for (int i = 0; i < nq; ++i) { + quantize_row_q8_2_x4(q, y, D); + q += stride_q; + y += D/QK8_2; + } + } + static inline void convert(int nq, int stride_q, const float * q, block_q8_KV * y) { for (int i = 0; i < nq; ++i) { quantize_row_q8_KV(q, y, D); @@ -15281,8 +15304,8 @@ template struct HelperQ80R8 : public BaseHelper { using Base = BaseHelper; #ifdef __AVX2__ - constexpr static int block_size_q = QK8_1; - using block_q8 = block_q8_1; + constexpr static int block_size_q = QK8_2; + using block_q8 = block_q8_2; #else constexpr static int block_size_q = QK8_0; using block_q8 = block_q8_0; @@ -15491,8 +15514,8 @@ struct HelperQ8KVR8 : public BaseHelper { template struct HelperQ40 final : public BaseHelper { using Base = BaseHelper; - using block_q8 = block_q8_0; - constexpr static int block_size_q = QK8_0; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; HelperQ40(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -15584,8 +15607,8 @@ struct HelperIQ4nl final : public BaseHelper { constexpr static int block_size_q = QK8_0; #else HelperIQ4nl(const char * data, int stride) : Base(data, stride) {} - using block_q8 = block_q8_1; - constexpr static int block_size_q = QK8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; #endif // Needed for v * softmax(k * q) @@ -15631,8 +15654,8 @@ struct HelperQ60 final : public BaseHelper { using block_q8 = block_q8_0; constexpr static int block_size_q = QK8_0; #else - using block_q8 = block_q8_1; - constexpr static int block_size_q = QK8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; #endif using Base = BaseHelper; HelperQ60(const char * data, int stride) : Base(data, stride) {} @@ -16350,7 +16373,7 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_0_q8_0>) { @@ -16383,7 +16406,7 @@ struct FlashQKfp32 { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_1_q8_1>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0 +void quantize_row_q8_1_x4_T(const float * x, Block * y, int64_t k) { assert(k % QK8_1 == 0); const int nb = k / QK8_1; const int nb4 = 4*(nb/4); - block_q8_1 * y = (block_q8_1 *)vy; - block_q8_1_x4 * y4 = (block_q8_1_x4 *)vy; + Block_x4 * y4 = (Block_x4 *)y; #if defined(__aarch64__) for (int i = 0; i < nb; i++) { int i4 = i/4, ir = i%4; @@ -851,10 +852,18 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { accv = vaddq_s32(accv, vi); } - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + if constexpr (std::is_same_v) { + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } else { + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } } else { - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + if (i < nb4) { + y4[i4].s[ir] = vaddvq_s32(accv); + } else { + y[i].s = vaddvq_s32(accv); + } } } #else @@ -880,13 +889,25 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { const float max_scalar = _mm_cvtss_f32( max4 ); // Quantize these floats - const float d = max_scalar / 127.f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + float d = max_scalar / 127.f; + if constexpr (std::is_same_v) { + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } } else { - y[i].d = GGML_FP32_TO_FP16(d); + if (i < nb4) { + auto t = GGML_FP32_TO_BF16(d); + y4[i4].d[ir] = t.bits; + d = ggml_bf16_to_fp32(t); + } else { + auto t = GGML_FP32_TO_BF16(d); + y[i].d = t.bits; + d = ggml_bf16_to_fp32(t); + } } - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const float id = d > 0 ? 1/d : 0.f; const __m256 mul = _mm256_set1_ps( id ); // Apply the multiplier @@ -908,10 +929,19 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { __m256i i3 = _mm256_cvtps_epi32( v3 ); // Compute the sum of the quants and set y[i].s - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + int isum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + if constexpr (std::is_same_v) { + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * isum); + } else { + y[i].s = GGML_FP32_TO_FP16(d * isum); + } } else { - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_BF16(d * isum).bits; + } else { + y[i].s = GGML_FP32_TO_BF16(d * isum).bits; + } } // Convert int32 to int16 @@ -934,6 +964,15 @@ void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { } #endif } +} + +void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { + quantize_row_q8_1_x4_T(x, (block_q8_1 *)vy, k); +} + +void quantize_row_q8_2_x4(const float * x, void * vy, int64_t k) { + quantize_row_q8_1_x4_T(x, (block_q8_2 *)vy, k); +} // // ============================================== iq2_K diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index dd148f2e..478bd0de 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -238,6 +238,7 @@ void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_2_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); From 23b0addb34d8942baedc6f968460560392feadd3 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 27 Mar 2025 10:48:52 +0100 Subject: [PATCH 062/132] Make sure tensor row size is multiple of block size also when quantizing with --pure (#294) * WIP - not working * q8_0 without bells and wistles works * It works for q8_0 * Use bf16 instead of f16,int16 * q4_0_r8 * q5_0_r4 * q6_0_r4 * Also q4_1 and q5_1 * Add check if selected type is possible with --pure I often want to quantize with --pure to see quantization performance without quantization mixes. But for models where there qre tensors with row sizes that are not multiple of 256, this results in a crash for k- and i-quants. Hence, lets add a check if the quant selected via --pure is applicable, and change it if not. --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 167 +++++++++++++++++++++++++------------------------- 1 file changed, 83 insertions(+), 84 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 24737265..779fe317 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16762,6 +16762,78 @@ static void llama_tensor_dequantize_internal( workers.clear(); } +static ggml_type change_type_if_necessar(ggml_type new_type, int nx, int ny) { + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || + new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || + new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || + new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || + new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 || + new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R8 || + new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || + new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || + new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| + new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || + new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || + new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4) { + if (nx % QK_K != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } + } + if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN || new_type == GGML_TYPE_IQ2_BN_R4) { + if (nx % QK_IQ1BN != 0) { + convert_incompatible_tensor = true; + } + } + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: + case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_XS_R8: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_Q4_K_R4: + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break; + case GGML_TYPE_IQ6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + } + return new_type; +} + static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { const std::string name = ggml_get_name(tensor); @@ -17260,90 +17332,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n LLAMA_LOG_INFO("Using custom type %s for tensor %s\n", ggml_type_name(new_type), name.c_str()); } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || - new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || - new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || - new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || - new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 || - new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R8 || - new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || - new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || - new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| - new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || - new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || - new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4) { - int nx = tensor->ne[0]; - int ny = tensor->ne[1]; - if (nx % QK_K != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - } - if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN || new_type == GGML_TYPE_IQ2_BN_R4) { - int nx = tensor->ne[0]; - if (nx % QK_IQ1BN != 0) { - convert_incompatible_tensor = true; - } - } - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XXS_R4: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_XS_R4: - case GGML_TYPE_IQ2_KS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ2_S_R4: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_XXS_R4: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ3_S_R4: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q2_K_R4: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q3_K_R4: - case GGML_TYPE_IQ2_K: - case GGML_TYPE_IQ2_K_R4: - case GGML_TYPE_IQ3_K: - case GGML_TYPE_IQ3_K_R4: - case GGML_TYPE_IQ4_KSS: - case GGML_TYPE_IQ4_KS: - case GGML_TYPE_IQ4_KS_R4: - case GGML_TYPE_IQ4_XS_R8: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_IQ4_K: - case GGML_TYPE_IQ4_K_R4: - case GGML_TYPE_Q4_K_R4: - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_IQ5_K: - case GGML_TYPE_IQ5_K_R4: - case GGML_TYPE_Q5_K_R4: - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break; - case GGML_TYPE_IQ6_K: - case GGML_TYPE_Q6_K_R4: - case GGML_TYPE_Q8_K_R8: - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + auto working_type = change_type_if_necessar(new_type, tensor->ne[0], tensor->ne[1]); + if (working_type != new_type) { ++qs.n_fallback; + new_type = working_type; } return new_type; @@ -17848,7 +17840,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } // get more optimal quantization type based on the tensor shape, layer, etc. - if (!params->pure && ggml_is_quantized(default_type)) { + if (params->pure) { + auto working_type = change_type_if_necessar(new_type, tensor->ne[0], tensor->ne[1]); + if (working_type != new_type) { + ++qs.n_fallback; + new_type = working_type; + } + } + else if (ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { From 4819257ce66a680608cf9c7871156041d00eb7da Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 29 Mar 2025 08:09:52 +0100 Subject: [PATCH 063/132] Quantization improvements (#295) * Better make_qx_quants Tested with q4_0 and q3_K (pure, imatrix), and the improvement is quite significant. * Sae for iq4_nl, iq4_xs --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-quants.c | 131 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 126 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e39cf4aa..b6c89e1d 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1768,10 +1768,8 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * float scale = suml2 ? sumlx/suml2 : 0.0f; if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; float best = scale * sumlx; + float best_sumlx = sumlx, best_suml2 = suml2; for (int is = -9; is <= 9; ++is) { - if (is == 0) { - continue; - } iscale = -(nmax + 0.1f*is) / max; sumlx = suml2 = 0; for (int i = 0; i < n; ++i) { @@ -1787,7 +1785,66 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); } scale = sumlx/suml2; best = scale*sumlx; + best_sumlx = sumlx; best_suml2 = suml2; } + iscale = (nmax-1 + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + best_sumlx = sumlx; best_suml2 = suml2; + } + } + + sumlx = best_sumlx; suml2 = best_suml2; + for (int iter = 0; iter < n*(2*nmax-1); ++iter) { + float abs_gmax = 0, gmax = 0; + int best_j = -1; + for (int j = 0; j < n; ++j) { + float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j])); + int l = L[j] - nmax; + float g = scale * w * (x[j] - scale*l); + if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) { + float ag = fabsf(g); + if (ag > abs_gmax) { + abs_gmax = ag; gmax = g; best_j = j; + } + } + } + if (best_j < 0) break; + + float new_sumlx = sumlx, new_suml2 = suml2; + float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j])); + int l = L[best_j] - nmax; + if (gmax > 0) { + new_sumlx += w*x[best_j]; + new_suml2 += w*(2*l + 1); + l += 1; + } else { + new_sumlx -= w*x[best_j]; + new_suml2 -= w*(2*l - 1); + l -= 1; + } + if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) { + sumlx = new_sumlx; suml2 = new_suml2; + scale = sumlx/suml2; best = scale*sumlx; + L[best_j] = l + nmax; + GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1); + } + else { + break; + } + } return scale; } @@ -3254,8 +3311,12 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri const int64_t nb = n_per_row/QK4_0; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK4_0 * ib; - const float * qw = quant_weights + QK4_0 * ib; - for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + if (quant_weights) { + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j]; + } float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); y[ib].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < 16; ++j) { @@ -14581,6 +14642,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block } d = sumqx/sumq2; float best = d*sumqx; + float best_sumqx = sumqx, best_sumq2 = sumq2; for (int itry = -ntry; itry <= ntry; ++itry) { id = (itry + values[0])/max; sumqx = sumq2 = 0; @@ -14594,8 +14656,67 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block } if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { d = sumqx/sumq2; best = d * sumqx; + best_sumqx = sumqx; best_sumq2 = sumq2; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + Lb[j] = best_index_iq4nl(values, al); + } + } + id = (itry + values[15])/max; + sumqx = sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_iq4nl(values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d * sumqx; + best_sumqx = sumqx; best_sumq2 = sumq2; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + Lb[j] = best_index_iq4nl(values, al); + } } } + sumqx = best_sumqx; sumq2 = best_sumq2; + for (int iter = 0; iter < 32*block_size; ++iter) { + float min_step = INFINITY; + int best_j = -1; int dir = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float g = d * w * (xb[j] - d*values[Lb[j]]); + if (g > 0 && Lb[j] < 15) { + float step = (values[Lb[j]+1] - values[Lb[j]])/g; + if (step < min_step) { + min_step = step; best_j = j; dir = 1; + } + } + else if (g < 0 && Lb[j] > 0) { + float step = (values[Lb[j]-1] - values[Lb[j]])/g; + if (step < min_step) { + min_step = step; best_j = j; dir = -1; + } + } + } + if (best_j < 0) break; + + float new_sumqx = sumqx, new_sumq2 = sumq2; + float w = weight[best_j]; + new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]); + new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]); + if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) { + sumqx = new_sumqx; sumq2 = new_sumq2; + d = sumqx/sumq2; best = d*sumqx; + Lb[best_j] += dir; + } + else { + break; + } + } + scales[ib] = d; float abs_d = fabsf(d); if (abs_d > amax_scale) { From 6e5156cab5c6d2858e1ecd5bc4dc5db81c71de39 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 1 Apr 2025 08:29:25 +0200 Subject: [PATCH 064/132] Fix #300 (#301) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index cf512ba5..1c8a991d 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -15514,8 +15514,13 @@ struct HelperQ8KVR8 : public BaseHelper { template struct HelperQ40 final : public BaseHelper { using Base = BaseHelper; +#if defined __AVX2__ using block_q8 = block_q8_2; constexpr static int block_size_q = QK8_2; +#else + using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; +#endif HelperQ40(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -15558,8 +15563,8 @@ struct HelperQ40 final : public BaseHelper { template struct HelperQ41 final : public BaseHelper { using Base = BaseHelper; - using block_q8 = block_q8_1; - constexpr static int block_size_q = QK8_1; + using block_q8 = block_q8_2; + constexpr static int block_size_q = QK8_2; HelperQ41(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -16414,7 +16419,7 @@ struct FlashQKfp32 { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0>) { From b07a337bfe033f00fdad3ad0e809eecd8f5f2d2c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 1 Apr 2025 08:29:47 +0200 Subject: [PATCH 065/132] Additional guards for interleaved quants (#299) * Make sure no interleaved quants are being used for token embeddings also with `--pure` and/or `--custom-q`. * Simplify --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 248 +++++++++++++------------------------------------- 1 file changed, 65 insertions(+), 183 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 779fe317..2ed50b20 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16762,7 +16762,7 @@ static void llama_tensor_dequantize_internal( workers.clear(); } -static ggml_type change_type_if_necessar(ggml_type new_type, int nx, int ny) { +static ggml_type change_type_if_necessary(ggml_type new_type, int nx, int ny) { bool convert_incompatible_tensor = false; if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS || @@ -16834,6 +16834,43 @@ static ggml_type change_type_if_necessar(ggml_type new_type, int nx, int ny) { return new_type; } +static std::pair interleaved_properties(ggml_type type) { + static std::unordered_map> k_map = { + { GGML_TYPE_Q4_0_4_4, { GGML_TYPE_Q4_0, 4} }, + { GGML_TYPE_Q4_0_4_8, { GGML_TYPE_Q4_0, 4} }, + { GGML_TYPE_Q4_0_8_8, { GGML_TYPE_Q4_0, 8} }, + { GGML_TYPE_Q4_0_R8, { GGML_TYPE_Q4_0, 8} }, + { GGML_TYPE_Q5_0_R4, { GGML_TYPE_Q5_0, 4} }, + { GGML_TYPE_Q6_0_R4, { GGML_TYPE_Q6_0, 4} }, + { GGML_TYPE_Q8_0_R8, { GGML_TYPE_Q8_0, 8} }, + { GGML_TYPE_Q2_K_R4, { GGML_TYPE_Q2_K, 4} }, + { GGML_TYPE_Q3_K_R4, { GGML_TYPE_Q3_K, 4} }, + { GGML_TYPE_Q4_K_R4, { GGML_TYPE_Q4_K, 4} }, + { GGML_TYPE_Q5_K_R4, { GGML_TYPE_Q5_K, 4} }, + { GGML_TYPE_Q6_K_R4, { GGML_TYPE_Q6_K, 4} }, + { GGML_TYPE_IQ2_XXS_R4, { GGML_TYPE_IQ2_XXS, 4} }, + { GGML_TYPE_IQ2_XS_R4, { GGML_TYPE_IQ2_XS, 4} }, + { GGML_TYPE_IQ2_S_R4, { GGML_TYPE_IQ2_S, 4} }, + { GGML_TYPE_IQ3_XXS_R4, { GGML_TYPE_IQ3_XXS, 4} }, + { GGML_TYPE_IQ3_S_R4, { GGML_TYPE_IQ3_S, 4} }, + { GGML_TYPE_IQ4_XS_R8, { GGML_TYPE_IQ4_XS, 8} }, + { GGML_TYPE_IQ4_NL_R4, { GGML_TYPE_IQ4_NL, 4} }, + { GGML_TYPE_IQ1_S_R4, { GGML_TYPE_IQ1_S, 4} }, + { GGML_TYPE_IQ1_M_R4, { GGML_TYPE_IQ1_M, 4} }, + { GGML_TYPE_IQ2_BN_R4, { GGML_TYPE_IQ2_BN, 4} }, + { GGML_TYPE_IQ2_K_R4, { GGML_TYPE_IQ2_K, 4} }, + { GGML_TYPE_IQ3_K_R4, { GGML_TYPE_IQ3_K, 4} }, + { GGML_TYPE_IQ4_K_R4, { GGML_TYPE_IQ4_K, 4} }, + { GGML_TYPE_IQ4_KS_R4, { GGML_TYPE_IQ4_KS, 4} }, + { GGML_TYPE_IQ5_K_R4, { GGML_TYPE_IQ5_K, 4} }, + { GGML_TYPE_Q8_KV_R8, { GGML_TYPE_Q8_KV, 8} }, + { GGML_TYPE_Q8_K_R8, { GGML_TYPE_Q8_K, 8} }, + { GGML_TYPE_BF16_R16, { GGML_TYPE_BF16, 16} }, + }; + if (auto it = k_map.find(type); it != k_map.end()) return it->second; + return {type, 1}; +} + static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { const std::string name = ggml_get_name(tensor); @@ -16939,70 +16976,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN_R4) { new_type = GGML_TYPE_IQ4_NL; } - else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || - new_type == GGML_TYPE_Q4_0_8_8) { - new_type = GGML_TYPE_Q4_0; - } - else if (new_type == GGML_TYPE_IQ4_NL_R4) { - new_type = GGML_TYPE_IQ4_NL; - } - else if (new_type == GGML_TYPE_IQ4_XS_R8) { - new_type = GGML_TYPE_IQ4_XS; - } - else if (new_type == GGML_TYPE_Q2_K_R4) { - new_type = GGML_TYPE_Q2_K; - } - else if (new_type == GGML_TYPE_Q3_K_R4) { - new_type = GGML_TYPE_Q3_K; - } - else if (new_type == GGML_TYPE_Q4_K_R4) { - new_type = GGML_TYPE_Q4_K; - } - else if (new_type == GGML_TYPE_Q5_K_R4) { - new_type = GGML_TYPE_Q5_K; - } - else if (new_type == GGML_TYPE_Q6_K_R4) { - new_type = GGML_TYPE_Q6_K; - } - else if (new_type == GGML_TYPE_Q8_K_R8) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type == GGML_TYPE_Q8_KV_R8) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type == GGML_TYPE_IQ2_K_R4) { - new_type = GGML_TYPE_IQ2_K; - } - else if (new_type == GGML_TYPE_IQ3_K_R4) { - new_type = GGML_TYPE_IQ3_K; - } - else if (new_type == GGML_TYPE_IQ3_S_R4) { - new_type = GGML_TYPE_IQ3_S; - } - else if (new_type == GGML_TYPE_IQ4_K_R4) { - new_type = GGML_TYPE_IQ4_K; - } - else if (new_type == GGML_TYPE_IQ5_K_R4) { - new_type = GGML_TYPE_IQ5_K; - } - else if (new_type == GGML_TYPE_IQ4_KS_R4) { - new_type = GGML_TYPE_IQ4_KS; - } - else if (new_type == GGML_TYPE_Q4_0_R8) { - new_type = GGML_TYPE_Q4_0; - } - else if (new_type == GGML_TYPE_Q5_0_R4) { - new_type = GGML_TYPE_Q5_0; - } - else if (new_type == GGML_TYPE_Q6_0_R4) { - new_type = GGML_TYPE_Q6_0; - } - else if (new_type == GGML_TYPE_Q8_0_R8) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type == GGML_TYPE_BF16_R16) { - new_type = GGML_TYPE_BF16; - } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M_R4) { if (name.find("attn_v.weight") != std::string::npos) { @@ -17332,12 +17305,21 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n LLAMA_LOG_INFO("Using custom type %s for tensor %s\n", ggml_type_name(new_type), name.c_str()); } - auto working_type = change_type_if_necessar(new_type, tensor->ne[0], tensor->ne[1]); + auto working_type = change_type_if_necessary(new_type, tensor->ne[0], tensor->ne[1]); if (working_type != new_type) { ++qs.n_fallback; new_type = working_type; } + if (name == "token_embd.weight") { + auto working_type = interleaved_properties(new_type).first; + if (working_type != new_type) { + printf("\n============ Token embeddings cannot be quantized with row-interleaved quants\n"); + printf("---> Changed %s to %s\n", ggml_type_name(new_type), ggml_type_name(working_type)); + new_type = working_type; + } + } + return new_type; } @@ -17834,14 +17816,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } if (quantize) { + new_type = default_type; - if (new_type == GGML_TYPE_BF16_R16 && strcmp(tensor->name, "token_embd.weight") == 0) { - new_type = GGML_TYPE_BF16; - } // get more optimal quantization type based on the tensor shape, layer, etc. if (params->pure) { - auto working_type = change_type_if_necessar(new_type, tensor->ne[0], tensor->ne[1]); + auto working_type = change_type_if_necessary(new_type, tensor->ne[0], tensor->ne[1]); if (working_type != new_type) { ++qs.n_fallback; new_type = working_type; @@ -17881,6 +17861,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_type = params->ffn_up_type; } + if (strcmp(tensor->name, "token_embd.weight") == 0) { + // token embeddings cannot be quantized with row-interleaved quants + auto working_type = interleaved_properties(new_type).first; + if (working_type != new_type) { + printf("\n============ Token embeddings cannot be quantized with row-interleaved quants\n"); + printf("---> Changed %s to %s\n", ggml_type_name(new_type), ggml_type_name(working_type)); + new_type = working_type; + } + } + // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. quantize = tensor->type != new_type; @@ -17965,119 +17955,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } int chunk_size_multiplier = 1; - if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) { - if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0; - else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; - if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; - else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_NL_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_NL; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_XS_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_IQ4_XS; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q4_0_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q4_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q5_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q6_0_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q8_0_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q2_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q3_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q3_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q4_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q5_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q6_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_Q8_K_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_Q8_KV_R8) { - if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; - else chunk_size_multiplier = 8; - } - else if (new_type == GGML_TYPE_IQ2_BN_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ3_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ5_K_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ5_K; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ4_KS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_KS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_XXS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_XXS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_XS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_XS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ2_S_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_S; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ3_XXS_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_XXS; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ3_S_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_S; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ1_S_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_S; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_IQ1_M_R4) { - if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ1_M; - else chunk_size_multiplier = 4; - } - else if (new_type == GGML_TYPE_BF16_R16) { - if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16; - else chunk_size_multiplier = 16; + auto [working_type, num_rows] = interleaved_properties(new_type); + if (tensor->ne[1] % num_rows != 0) { + new_type = working_type; + } else { + chunk_size_multiplier = num_rows; } LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); From 190e7866db1d87a5da8b2d2b8d6619092b2ec72c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 1 Apr 2025 10:31:06 +0200 Subject: [PATCH 066/132] Quantization improvements (2) (#302) * iq3_k: slightly better quantization Not much of a difference for most models, but this change avoids what it looks like a catastrophic failure for DeepSeek-Lite (PPL is now 7.041 vs 7.314 on main). * Small improvement for type-1 quants --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-quants.c | 133 +++++++++++++++++++++++----------- ggml/src/iqk/iqk_quantize.cpp | 62 ++++++++++++---- 2 files changed, 141 insertions(+), 54 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index b6c89e1d..e43a9b22 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -2199,8 +2199,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f float rmin, float rdelta, int nstep, bool use_mad) { float min = x[0]; float max = x[0]; - float sum_w = weights ? weights[0] : x[0]*x[0]; - float sum_x = sum_w * x[0]; + double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]); + double sum_x = sum_w * (double)x[0]; + double sum_x2 = sum_w * (double)x[0] * (double)x[0]; #ifdef HAVE_BUGGY_APPLE_LINKER // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 for (volatile int i = 1; i < n; ++i) { @@ -2210,8 +2211,9 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f if (x[i] < min) min = x[i]; if (x[i] > max) max = x[i]; float w = weights ? weights[i] : x[i]*x[i]; - sum_w += w; - sum_x += w * x[i]; + sum_w += (double)w; + sum_x += (double)w * (double)x[i]; + sum_x2 += (double)w * (double)x[i] * (double)x[i]; } if (min > 0) { min = 0; @@ -2223,13 +2225,13 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f } float iscale = nmax/(max - min); float scale = 1/iscale; - float best_mad = 0; + double best_mad = 0; for (int i = 0; i < n; ++i) { int l = nearest_int(iscale*(x[i] - min)); L[i] = MAX(0, MIN(nmax, l)); - float diff = scale * L[i] + min - x[i]; - diff = use_mad ? fabsf(diff) : diff*diff; - float w = weights ? weights[i] : x[i]*x[i]; + double diff = (double)scale * L[i] + (double)min - (double)x[i]; + diff = use_mad ? fabs(diff) : diff*diff; + double w = weights ? (double)weights[i] : (double)(x[i]*x[i]); best_mad += w * diff; } if (nstep < 1) { @@ -2238,30 +2240,35 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f } for (int is = 0; is <= nstep; ++is) { iscale = (rmin + rdelta*is + nmax)/(max - min); - float sum_l = 0, sum_l2 = 0, sum_xl = 0; + double sum_l = 0, sum_l2 = 0, sum_xl = 0; for (int i = 0; i < n; ++i) { int l = nearest_int(iscale*(x[i] - min)); l = MAX(0, MIN(nmax, l)); Laux[i] = l; float w = weights ? weights[i] : x[i]*x[i]; - sum_l += w*l; - sum_l2 += w*l*l; - sum_xl += w*l*x[i]; + sum_l += (double)w*l; + sum_l2 += (double)w*l*l; + sum_xl += (double)w*l*(double)x[i]; } - float D = sum_w * sum_l2 - sum_l * sum_l; + double D = sum_w * sum_l2 - sum_l * sum_l; if (D > 0) { - float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; - float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; if (this_min > 0) { this_min = 0; this_scale = sum_xl / sum_l2; } - float mad = 0; - for (int i = 0; i < n; ++i) { - float diff = this_scale * Laux[i] + this_min - x[i]; - diff = use_mad ? fabsf(diff) : diff*diff; - float w = weights ? weights[i] : x[i]*x[i]; - mad += w * diff; + double mad = 0; + if (use_mad) { + for (int i = 0; i < n; ++i) { + double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i]; + diff = fabs(diff); + double w = weights ? (double)weights[i] : (double)(x[i]*x[i]); + mad += w * diff; + } + } else { + mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l + + this_scale*this_scale*sum_l2 + this_min*this_min*sum_w; } if (mad < best_mad) { for (int i = 0; i < n; ++i) { @@ -2273,6 +2280,57 @@ static float make_qkx3_quants(int n, int nmax, const float * restrict x, const f } } } + if (use_mad) { + *the_min = -min; + return scale; + } + + double sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = L[i]; + double w = weights ? (double)weights[i] : (double)(x[i]*x[i]); + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*(double)x[i]; + } + double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l + - (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w; + int last_j = -1, last_dir = 0; + for (int itry = 0; itry < nmax*n; ++itry) { + float gmax = 0; + int best_j = -1, dir = 0; + for (int j = 0; j < n; ++j) { + float g = x[j] - scale*L[j] - min; + if (g > 0 && L[j] < nmax && g > gmax) { + gmax = g; best_j = j; dir = 1; + } + else if (g < 0 && L[j] > 0 && -g > gmax) { + gmax = -g; best_j = j; dir = -1; + } + } + if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break; + double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]); + sum_l += w*dir; + sum_l2 += w*(2*L[best_j]*dir + 1); + sum_xl += w*(double)x[best_j]*dir; + double D = (double)sum_w * sum_l2 - sum_l * sum_l; + if (D <= 0) break; + double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D; + double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + if (this_scale < 0) break; + double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l + - this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w; + if (score <= best) break; + best = score; + scale = this_scale; + min = this_min; + L[best_j] += dir; + last_j = best_j; last_dir = dir; + } *the_min = -min; return scale; } @@ -2354,7 +2412,6 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri GGML_ASSERT(quant_weights); assert(k % QK_K == 0); const int nb = k / QK_K; - const bool requantize = true; uint8_t L[QK_K]; uint8_t Laux[16]; @@ -2368,7 +2425,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri memset(sw, 0, QK_K/16*sizeof(float)); float sumx2 = 0; for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j]; - float sigma2 = sumx2/QK_K; + float sigma2 = 0.75f*sumx2/QK_K; for (int j = 0; j < QK_K/16; ++j) { const float * restrict qw = quant_weights + QK_K * i + 16*j; for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]); @@ -2376,31 +2433,25 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false); } - float dm, mm; - dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); - mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); + float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); + float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); y[i].d = GGML_FP32_TO_FP16(dm); y[i].dmin = GGML_FP32_TO_FP16(mm); - dm = GGML_FP16_TO_FP32(y[i].d); - mm = GGML_FP16_TO_FP32(y[i].dmin); for (int j = 0; j < QK_K/16; ++j) { - y[i].scales[j] = Ls[j] | (Lm[j] << 4); - } - - if (requantize) { - for (int j = 0; j < QK_K/16; ++j) { - const float d = dm * (y[i].scales[j] & 0xF); - if (!d) continue; - const float m = mm * (y[i].scales[j] >> 4); - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int((x[16*j + ii] + m)/d); - l = MAX(0, MIN(3, l)); - L[16*j + ii] = l; - } + float d = dm*Ls[j]; + float m = mm*Lm[j]; + float id = d ? 1/d : 0.f; + for (int l = 0; l < QK_K/16; ++l) { + int q = nearest_int((x[16*j + l] + m)*id); + q = MAX(0, MIN(3, q)); + L[16*j + l] = q; } } + for (int j = 0; j < QK_K/16; ++j) { + y[i].scales[j] = Ls[j] | (Lm[j] << 4); + } for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index cac1fd49..e2cea7df 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -1555,12 +1555,13 @@ inline int best_index_iq3nl(const int8_t * values, float x) { static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, const float * quant_weights) { - const int ntry = 5; + constexpr int ntry = 3; block_iq3_k * y = (block_iq3_k *)vy; float scales[QK_K/16]; float weight[16]; + uint8_t L[16]; const int8_t * shifted_values = iq3nl_values + 8; @@ -1620,7 +1621,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c } bool is_shifted = false; for (int itry = -ntry; itry <= ntry; ++itry) { - id = (itry + iq3nl_values[0])/max; + id = (2*itry + iq3nl_values[0])/max; sumqx_p = sumq2_p = 0; sumqx_m = sumq2_m = 0; for (int j = 0; j < 16; ++j) { @@ -1641,7 +1642,7 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c if (sumq2_m > 0 && sumqx_m*sumqx_m > best*sumq2_m) { d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = false; } - id = (itry + shifted_values[0])/max; + id = (2*itry + shifted_values[0])/max; sumqx_p = sumq2_p = 0; sumqx_m = sumq2_m = 0; for (int j = 0; j < 16; ++j) { @@ -1663,20 +1664,55 @@ static void quantize_row_iq3_k_impl(const float * x, void * vy, int n_per_row, c d = sumqx_m/sumq2_m; best = d * sumqx_m; is_shifted = true; } } - if (d) { - const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; - float sumqx = 0, sumq2 = 0; - id = 1/d; + if (!d) { + scales[ib] = 0; continue; + } + + const int8_t * block_values = is_shifted ? shifted_values : iq3nl_values; + float sumqx = 0, sumq2 = 0; + id = 1/d; + for (int j = 0; j < 16; ++j) { + float w = weight[j]; + float al = id*xb[j]; + int l = best_index_iq3nl(block_values, al); + L[j] = l; + float q = block_values[l]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0) d = sumqx/sumq2; + + float best_d = d; + for (int iter = 0; iter < 128; ++iter) { + float gmax = 0; + int best_j = -1, dir = 0; for (int j = 0; j < 16; ++j) { float w = weight[j]; - float al = id*xb[j]; - int l = best_index_iq3nl(block_values, al); - float q = block_values[l]; - sumqx += w*q*xb[j]; - sumq2 += w*q*q; + float g = d * w * (xb[j] - d*block_values[L[j]]); + if (g > 0 && L[j] < 7) { + if (g > gmax) { + gmax = g; best_j = j; dir = 1; + } + } + else if (g < 0 && L[j] > 0) { + if (-g > gmax) { + gmax = -g; best_j = j; dir = -1; + } + } } - if (sumq2 > 0) d = sumqx/sumq2; + if (best_j < 0) break; + + float w = weight[best_j]; + sumqx += w*xb[best_j]*(block_values[L[best_j]+dir] - block_values[L[best_j]]); + sumq2 += w*(block_values[L[best_j]+dir]*block_values[L[best_j]+dir] - block_values[L[best_j]]*block_values[L[best_j]]); + L[best_j] += dir; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + best_d = sumqx/sumq2; best = best_d*sumqx; + } + else if (iter > 8) break; + } + scales[ib] = d; if (is_shifted) extra |= (1 << ib); From 21a5b8bd2820c097c1aa03d85cb83658609c7a12 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 1 Apr 2025 13:48:20 +0200 Subject: [PATCH 067/132] Fix ARM_NEON build failure due to q8_2 (#303) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-alloc.c | 4 ++-- ggml/src/iqk/iqk_quantize.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index d811dee6..3f2d2023 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -174,8 +174,8 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz // this should never happen fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", __func__, size, max_avail); - fprintf(stderr, "%s: tensor was %s with %zu elements and %zu bytes\n", __func__, tensor->name, - ggml_nelements(tensor), ggml_nbytes(tensor)); + fprintf(stderr, "%s: tensor was %s with %g elements and %zu bytes\n", __func__, tensor->name, + 1.*ggml_nelements(tensor), ggml_nbytes(tensor)); GGML_ABORT("not enough space in the buffer"); } } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index e2cea7df..7873d3fe 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -860,9 +860,9 @@ void quantize_row_q8_1_x4_T(const float * x, Block * y, int64_t k) { } } else { if (i < nb4) { - y4[i4].s[ir] = vaddvq_s32(accv); + y4[i4].d[ir+4] = GGML_FP32_TO_BF16(d * vaddvq_s32(accv)).bits; } else { - y[i].s = vaddvq_s32(accv); + y[i].s = GGML_FP32_TO_BF16(d * vaddvq_s32(accv)).bits; } } } From 6d405d1fd1bddfe31fcbf00c9c8652a0a9166887 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 2 Apr 2025 04:30:25 +0900 Subject: [PATCH 068/132] docs: update README.md (#304) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 782d0150..207d6de2 100644 --- a/README.md +++ b/README.md @@ -265,5 +265,5 @@ To reproduce these results: ## To tile or not to tile -The common wisdom for efficient matrix multiplications is to use block tiling, and this is also used here for `fp16/fp32` matrices. But block tiling does not somehow magically reduce the amount of computation that needs to get done. Performance gains are simply due to the better utilization of memory caches. When dealing with quantized matrix multiplications, there is an additional factor that comes into play: the quantized data needs to be unpacked to 8-bit integers before being used in the matrix multiplication multiply-add operations. Depending on quantization type, this unpacking can represent a significant fraction of the overall computation cost. Hence, for best performance, one would want to reuse the unpacked quants as much as possible, thus spending some fraction of the available vector registers to hold the unpacked data. But when using block tiling, one also needs a certain number of vector registers for accumulating results. For instance, on `AVX2` (16 vector registers available), for `fp16/fp32` models best performance is achieved with `2 x 6` tiles (where the `2` refers to rows in the left matrix and is measured in units of the vector register size, so 16/8 floats for `fp16/fp32`, and `6` is for the number of columns in the right matrix). Unpacking quantized data works best when done in blocks of 128 or 256 quants so that, if we wanted to keep unpacked quants for 2 rows, we would need at least 8 vector registers, thus being left with less than 8 registers for result accumulation, so at best `2 x 3` tiles. In practice one needs addition vector registers for various constants that are typically needed for de-quantization, so that, at the end, it becomes better to use `1 x N` "tiles", i.e., a row-wise multiplication where each row in the left matrix is multiplied with `N` columns in the right matrix, thus reusing the unpacked data `N` times. This (i.e., amortizing de-quantization cost) is the main mechanism for seeding up quantized matrix multiplications. Having started with quantized matrices, and having gone from tiles to a row-wise implementation after some experimentation, I did try row-wise multiplication for float matrices first. Performance was not quite as good as for block-tiling, but I did get up to 90-95% of the speed of `tinyBLAS` that way before switching the `fp16/fp32` implementation to `2 x 6` (`AVX2`) or `5 x 5` (`AVX512` and `ARM_NEON`) block-tiles. But even for for `Q8_0 x Q8_0` multiplications, where there is basically no de-quantization cost, row-wise multiplication is faster than tiling (and hence this implemeintation beats `tinyBLAS`, which uses block-tiling also for `Q8_0`). +The common wisdom for efficient matrix multiplications is to use block tiling, and this is also used here for `fp16/fp32` matrices. But block tiling does not somehow magically reduce the amount of computation that needs to get done. Performance gains are simply due to the better utilization of memory caches. When dealing with quantized matrix multiplications, there is an additional factor that comes into play: the quantized data needs to be unpacked to 8-bit integers before being used in the matrix multiplication multiply-add operations. Depending on quantization type, this unpacking can represent a significant fraction of the overall computation cost. Hence, for best performance, one would want to reuse the unpacked quants as much as possible, thus spending some fraction of the available vector registers to hold the unpacked data. But when using block tiling, one also needs a certain number of vector registers for accumulating results. For instance, on `AVX2` (16 vector registers available), for `fp16/fp32` models best performance is achieved with `2 x 6` tiles (where the `2` refers to rows in the left matrix and is measured in units of the vector register size, so 16/8 floats for `fp16/fp32`, and `6` is for the number of columns in the right matrix). Unpacking quantized data works best when done in blocks of 128 or 256 quants so that, if we wanted to keep unpacked quants for 2 rows, we would need at least 8 vector registers, thus being left with less than 8 registers for result accumulation, so at best `2 x 3` tiles. In practice one needs addition vector registers for various constants that are typically needed for de-quantization, so that, at the end, it becomes better to use `1 x N` "tiles", i.e., a row-wise multiplication where each row in the left matrix is multiplied with `N` columns in the right matrix, thus reusing the unpacked data `N` times. This (i.e., amortizing de-quantization cost) is the main mechanism for seeding up quantized matrix multiplications. Having started with quantized matrices, and having gone from tiles to a row-wise implementation after some experimentation, I did try row-wise multiplication for float matrices first. Performance was not quite as good as for block-tiling, but I did get up to 90-95% of the speed of `tinyBLAS` that way before switching the `fp16/fp32` implementation to `2 x 6` (`AVX2`) or `5 x 5` (`AVX512` and `ARM_NEON`) block-tiles. But even for for `Q8_0 x Q8_0` multiplications, where there is basically no de-quantization cost, row-wise multiplication is faster than tiling (and hence this implementation beats `tinyBLAS`, which uses block-tiling also for `Q8_0`). From 07dbc1aa06d761634419759431ebb215baf698bb Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 3 Apr 2025 07:15:49 +0200 Subject: [PATCH 069/132] Metal: much faster MoE prompt processing (#307) * MoE improvements on Metal This version beats mainline, there are things I don't understand: * Mianline has effectively gone to GEMV for MUL_MAT_ID. We can do the same, but we are 30% slower. Why? * Using actual GEMM, we beat mainline with ubtach size of 128. But then performance degrades. Why? * Some cleanup * Much better --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-metal.m | 4248 +++++++++++++++++++------------------ ggml/src/ggml-metal.metal | 144 +- 2 files changed, 2234 insertions(+), 2158 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 0498be1f..aa50d448 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -297,8 +297,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COUNT }; +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + struct ggml_backend_metal_context { - int n_cb; id device; id queue; @@ -307,6 +308,26 @@ struct ggml_backend_metal_context { struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + bool support_simdgroup_reduction; bool support_simdgroup_mm; @@ -373,7 +394,6 @@ static void * ggml_metal_host_malloc(size_t n) { const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; } #endif @@ -533,6 +553,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); ctx->should_capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; + + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->command_buffers[i] = nil; + } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { @@ -846,6 +874,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->kernels[i].pipeline release]; } + Block_release(ctx->encode_async); + [ctx->queue release]; [ctx->device release]; @@ -1027,934 +1057,871 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx } } -static enum ggml_status ggml_metal_graph_compute( +static void ggml_metal_encode_node( struct ggml_backend_metal_context * ctx, - struct ggml_cgraph * gf) { + struct ggml_tensor * node, + id encoder) { - @autoreleasepool { - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - edesc.dispatchType = MTLDispatchTypeSerial; - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel + struct ggml_tensor * src0 = node->src[0]; + struct ggml_tensor * src1 = node->src[1]; + struct ggml_tensor * src2 = node->src[2]; + struct ggml_tensor * dst = node; - const int n_nodes = gf->n_nodes; - const int n_cb = ctx->n_cb; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - - const bool should_capture = ctx->should_capture_next_compute; - if (should_capture) { - ctx->should_capture_next_compute = false; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->queue; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - GGML_ABORT("capture failed"); - } + if (ggml_is_empty(dst)) { + return; } - id command_buffer_builder[n_cb]; - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - command_buffer_builder[cb_idx] = command_buffer; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; - } + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: return; // noop + default: break; } - const id *command_buffers = command_buffer_builder; + if (!ggml_metal_supports_op(ctx, dst)) { + GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ABORT("unsupported op"); + } - dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { - const int cb_idx = iter; + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; - id command_buffer = command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - for (int i = node_start; i < node_end; ++i) { - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * src2 = gf->nodes[i]->src[2]; - struct ggml_tensor * dst = gf->nodes[i]; + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; - if (ggml_is_empty(dst)) { - continue; - } + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } continue; + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + + //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_CONCAT: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const size_t offs = 0; + + bool bcast_row = false; + + int64_t nb = ne00; // used by the "row" kernels + + id pipeline = nil; + + if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { + float scale; + memcpy(&scale, src1->data, sizeof(float)); + //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + + int64_t n = ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && + dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && + dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && + dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && + dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { + + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; + default: GGML_ASSERT(false); + } + + int64_t n = ggml_nelements(dst)/4; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + break; + } + else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case GGML_OP_MULTI_ADD: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + GGML_ASSERT(ne02 == 1 && ne03 == 1); + GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + int n_expert = dst->op_params[0]; + GGML_ASSERT(n_expert >= 2); + + id pipeline = nil; + int64_t n = ne0*ne1; + if (ne0%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((int32_t *) dst->op_params)[2]; + const size_t offs = ((int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFTCAP: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scales[2]; + memcpy(scales, dst->op_params, sizeof(scales)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; + [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + case GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SWIGLU: + { + int64_t n = ggml_nelements(dst); + GGML_ASSERT(ne0 == src0->ne[0]/2); + + id pipeline = nil; + + uint32_t n_per_row = ne0; + uint32_t stride = src0->nb[1]/sizeof(float); + + if (ne0 % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; + n /= 4; + n_per_row /= 4; + stride /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; + [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx, dst)) { - GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; - } - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - const size_t offs = 0; - - bool bcast_row = false; - - int64_t nb = ne00; // used by the "row" kernels - - id pipeline = nil; - - if (dst->op == GGML_OP_MUL && ggml_nelements(src1) == 1 && ggml_is_contiguous(src0)) { - float scale; - memcpy(&scale, src1->data, sizeof(float)); - //printf("Replacing op_mul with op_scale. scale = %g\n", (double)scale); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - - int64_t n = ggml_nelements(dst); - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - break; - } - else if (ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && ggml_is_contiguous(dst) && - dst->src[0]->ne[0] == dst->src[1]->ne[0] && dst->src[0]->ne[0] == dst->ne[0] && - dst->src[0]->ne[1] == dst->src[1]->ne[1] && dst->src[0]->ne[1] == dst->ne[1] && - dst->src[0]->ne[2] == dst->src[1]->ne[2] && dst->src[0]->ne[2] == dst->ne[2] && - dst->src[0]->ne[3] == dst->src[1]->ne[3] && ggml_nelements(dst)%4 == 0) { - - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_4].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_4].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_4].pipeline; break; - default: GGML_ASSERT(false); - } - - int64_t n = ggml_nelements(dst)/4; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - break; - } - else if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - nb = ne00 / 4; - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_MULTI_ADD: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - GGML_ASSERT(ne02 == 1 && ne03 == 1); - GGML_ASSERT(nb0 == sizeof(float) && nb00 == sizeof(float)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - int n_expert = dst->op_params[0]; - GGML_ASSERT(n_expert >= 2); - - id pipeline = nil; - int64_t n = ne0*ne1; - if (ne0%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MULTI_ADD].pipeline; - } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:2]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:3]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&n_expert length:sizeof(n_expert) atIndex:6]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFTCAP: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scales[2]; - memcpy(scales, dst->op_params, sizeof(scales)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFTCAP].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scales[0] length:sizeof(float) atIndex:2]; - [encoder setBytes:&scales[1] length:sizeof(float) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SWIGLU: - { - int64_t n = ggml_nelements(dst); - GGML_ASSERT(ne0 == src0->ne[0]/2); - - id pipeline = nil; - - uint32_t n_per_row = ne0; - uint32_t stride = src0->nb[1]/sizeof(float); - - if (ne0 % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_4].pipeline; - n /= 4; - n_per_row /= 4; - stride /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&n_per_row length:sizeof(n_per_row) atIndex:2]; - [encoder setBytes:&stride length:sizeof(stride) atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_FUSED_MUL_UNARY: - { - int64_t n = ggml_nelements(dst); - enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; - id pipeline = nil; - if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { - pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline - : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; - n /= 4; - } else { - pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline - : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline - : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; - } - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SOFT_CAP_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - float s_before; - float s_after; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); - memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; - [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; + { + GGML_METAL_LOG_WARN("%s: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_FUSED_MUL_UNARY: + { + int64_t n = ggml_nelements(dst); + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + id pipeline = nil; + if (n % 4 == 0 && op != GGML_UNARY_OP_RELU) { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU_4].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU_4].pipeline; + n /= 4; + } else { + pipeline = op == GGML_UNARY_OP_GELU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_GELU].pipeline + : op == GGML_UNARY_OP_SILU ? ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_SILU].pipeline + : ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_RELU].pipeline; + } + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SOFT_CAP_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_CAP_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + float s_before; + float s_after; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&s_before, ((int32_t *) dst->op_params) + 2, sizeof(s_before)); + memcpy(&s_after, ((int32_t *) dst->op_params) + 3, sizeof(s_after)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&s_before length:sizeof(s_before) atIndex:10]; + [encoder setBytes:&s_after length:sizeof(s_after ) atIndex:11]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:12]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_MUL_MAT: + { + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 4; #if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { switch (src0t) { case GGML_TYPE_F16: ne11_mm_min = 2; break; @@ -1976,11 +1943,11 @@ static enum ggml_status ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { + !ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers @@ -2305,9 +2272,9 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S|| + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_Q6_0) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { @@ -2326,8 +2293,8 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || - src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS) { + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| + src0t == GGML_TYPE_IQ4_KSS) { const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -2348,1179 +2315,1254 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } - } break; - case GGML_OP_MUL_MAT_ID: - { - const int n_as = src0->ne[2]; - - // src2 = ids - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_BF16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q6_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ1_BN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; - } break; - case GGML_TYPE_IQ2_BN: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ4_KS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; - } break; - case GGML_TYPE_IQ4_KSS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_KS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; - } break; - case GGML_TYPE_IQ4_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; - } break; - case GGML_TYPE_IQ5_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; - } break; - case GGML_TYPE_IQ6_K: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; - - const int64_t _ne1 = 1; - const int tgz = dst_rows; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || - src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { - const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || - src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| - src0t == GGML_TYPE_IQ4_KSS) { - const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_FUSED_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(src1->ne[0] == src0->ne[0]); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_nrows(src1) == 1); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&eps length:sizeof( float) atIndex:5]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & 2; - - id pipeline = nil; - - if (!is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_are_same_shape (src1, src2)); - - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - - size_t offs_src3 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); - - float scale; - float max_bias; - float softcap; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); - if (softcap != 0.0f) { - scale /= softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); - if (smem > ctx->device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half1x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_as = src0->ne[2]; + + // src2 = ids + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); + + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src1t == GGML_TYPE_F32); + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; + //const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength/2 - 8192)/4; + const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 8192)/4; + + // max size of the rowids array in the kernel shared buffer + //GGML_ASSERT(dst_rows <= dst_rows_max); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + dst_rows > dst_rows_min && + dst_rows <= dst_rows_max) { + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; } - } - if (should_capture) { - [encoder popDebugGroup]; + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_0_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_KSS_F32].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ6_K_F32 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q6_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_0_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ1_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ2_BN: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_BN_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 32; + nth1 = 2; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ4_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KS_F32].pipeline; + } break; + case GGML_TYPE_IQ4_KSS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_KSS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_KS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_KS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_K_F32].pipeline; + } break; + case GGML_TYPE_IQ4_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_K_F32].pipeline; + } break; + case GGML_TYPE_IQ5_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ5_K_F32].pipeline; + } break; + case GGML_TYPE_IQ6_K: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ6_K_F32].pipeline; + } break; + default: + { + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); + GGML_ABORT("not implemented"); + } + }; + + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_Q6_0 || + src0t == GGML_TYPE_IQ1_BN|| src0t == GGML_TYPE_IQ2_BN|| src0t == GGML_TYPE_IQ2_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_KS || src0t == GGML_TYPE_IQ2_K || src0t == GGML_TYPE_IQ3_K) { + const int mem_size = src0t == GGML_TYPE_IQ2_KS ? 64*sizeof(float) : src0t == GGML_TYPE_IQ3_K ? 32*sizeof(float) : 16*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS || src0t == GGML_TYPE_IQ4_K || + src0t == GGML_TYPE_IQ5_K || src0t == GGML_TYPE_IQ6_K || src0t == GGML_TYPE_IQ4_KS|| + src0t == GGML_TYPE_IQ4_KSS) { + const int mem_size = src0t == GGML_TYPE_IQ6_K ? 128*sizeof(float) : GGML_TYPE_IQ5_K ? 64*sizeof(float) : 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_0 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_BN ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_BN ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KS ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_KSS].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_K ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_KS ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_K ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_K ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ5_K ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_FUSED_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src1->ne[0] == src0->ne[0]); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nrows(src1) == 1); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < 1024) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&eps length:sizeof( float) atIndex:5]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous(src0)); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + const int32_t n_groups = ((int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = MIN(256, ne00); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & 2; + + id pipeline = nil; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = nil; + + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + id pipeline = nil; + + switch (order) { + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_are_same_shape (src1, src2)); + + struct ggml_tensor * src3 = node->src[3]; + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + float max_bias; + float softcap; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&softcap, ((int32_t *) dst->op_params) + 2, sizeof(softcap)); + if (softcap != 0.0f) { + scale /= softcap; + } + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&scale length:sizeof( float) atIndex:23]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; + [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + default: GGML_ABORT("not implemented"); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + GGML_METAL_LOG_ERROR("%s: error: node %s, op = %8s not implemented\n", __func__, dst->name, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } + +} + +static enum ggml_status ggml_metal_graph_compute( + struct ggml_backend_metal_context * ctx, + struct ggml_cgraph * gf) { + + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device]; //ctx_dev->mtl_device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + printf("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } } } - [encoder endEncoding]; + // the main thread commits the first few commands immediately + // command_buffer[n_cb] + { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[n_cb] = command_buffer; - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; + [command_buffer enqueue]; + ctx->encode_async(n_cb); } - }); - // Wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) + // prepare the rest of the command buffers asynchronously + // command_buffer[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[cb_idx] = command_buffer; - for (int i = 0; i < n_cb; ++i) { - id command_buffer = command_buffers[i]; - [command_buffer waitUntilCompleted]; + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; + } + } - MTLCommandBufferStatus status = [command_buffer status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - NSString * error_code = [command_buffer error].localizedDescription; - GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id command_buffer = ctx->command_buffers[n_cb]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + printf("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + for (int i = 0; i < n_cb; ++i) { + id command_buffer = ctx->command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + printf("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + printf("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; } - return GGML_STATUS_FAILED; + id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + printf("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; } - id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); - if (!next_buffer) { - continue; + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; } - - bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; } - if (should_capture) { - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - - } return GGML_STATUS_SUCCESS; } @@ -3905,8 +3947,60 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { GGML_ASSERT(ggml_backend_is_metal(backend)); struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoder]; + + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + const bool should_capture = ctx->capture_next_compute; + + for (int idx = node_start; idx < node_end; ++idx) { + struct ggml_tensor * node = ctx->gf->nodes[idx]; + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(node) encoding:NSUTF8StringEncoding]]; + } + + ggml_metal_encode_node(ctx, node, encoder); + + if (should_capture) { + [encoder popDebugGroup]; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } + }); - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); } void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 89cd412a..a67ec336 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1652,13 +1652,18 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } + device const block_q8_0 * xr = x + ib; + for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + //device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + device const int8_t * qs = xr->qs + NB_Q8_0*il; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } - sumf[row] += sumq*x[ib+row*nb].d; + //sumf[row] += sumq*x[ib+row*nb].d; + sumf[row] += sumq*xr->d; + xr += nb; } yb += NB_Q8_0 * nw; @@ -7218,10 +7223,16 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); + if constexpr (is_same_v) { + const half d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (half)qs[i + 16*il] * d; + } + } else { + const float d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = qs[i + 16*il] * d; + } } } @@ -8246,39 +8257,6 @@ kernel void kernel_mul_mm_id( uint ntg = ntg3.x * ntg3.y * ntg3.z; uint n = nei0*nei1; - //uint npt = (n + ntg - 1) / ntg; - //uint first = tiitg * npt; - //uint last = first + npt <= n ? first + npt : n; - - //uint nhave = 0; - //for (uint i = first; i < last; ++i) { - // uint ii0 = i % nei0; - // uint ii1 = i / nei0; - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) ++nhave; - //} - //threadgroup uint * nums = (threadgroup uint *)shared_memory; - //nums[tiitg] = nhave; - //threadgroup_barrier(mem_flags::mem_threadgroup); - - //uint nprev = 0; - //for (uint i = 0; i < tiitg; ++i) nprev += nums[i]; - //int64_t _ne1 = nprev; - //for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i]; - - //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - //for (uint i = first; i < last; ++i) { - // uint ii0 = i % nei0; - // uint ii1 = i / nei0; - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) rowids[nprev++] = ushort2(ii0, ii1); - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - - // - // The following is slightly faster than the commented out version above - // uint nhave = 0; for (uint i = tiitg; i < n; i += ntg) { uint ii0 = i % nei0; @@ -8290,10 +8268,24 @@ kernel void kernel_mul_mm_id( nums[tiitg] = nhave; threadgroup_barrier(mem_flags::mem_threadgroup); - uint nprev = 0; - for (uint i = 0; i < tiitg; ++i) nprev += nums[i]; - int64_t _ne1 = nprev; - for (uint i = tiitg; i < ntg; ++i) _ne1 += nums[i]; + uint stride = 1; + while (stride <= ntg/2) { + uint index = (tiitg+1)*stride*2 - 1; // index - stride = 2*tiitg*stride + stride - 1; + if (index < ntg) nums[index] += nums[index-stride]; + stride <<= 1; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + stride = ntg/2; + while (stride > 0) { + uint index = (tiitg+1)*stride*2 - 1; + if (index+stride < ntg) nums[index+stride] += nums[index]; + stride >>= 1; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + uint _ne1 = nums[ntg-1]; + if (!_ne1) return; + + uint nprev = tiitg > 0 ? nums[tiitg-1] : 0; threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); for (uint i = tiitg; i < n; i += ntg) { @@ -8304,47 +8296,37 @@ kernel void kernel_mul_mm_id( } threadgroup_barrier(mem_flags::mem_threadgroup); - // This is the original version that is ridiculously slow. - //// row indices - //threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + uint nstep = (_ne1 + BLOCK_SIZE_N - 1)/BLOCK_SIZE_N; - //// TODO: parallelize this loop - //int64_t _ne1 = 0; - //for (ushort ii1 = 0; ii1 < nei1; ii1++) { - // for (ushort ii0 = 0; ii0 < nei0; ii0++) { - // int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - // if (id == i02) { - // //if (tiitg == 0) { - // rowids[_ne1] = ushort2(ii0, ii1); - // //} - // _ne1++; - // } - // } - //} + for (uint istep = 0; istep < nstep; ++istep) { - //threadgroup_barrier(mem_flags::mem_threadgroup); + uint first = BLOCK_SIZE_N*istep; + uint last = first + BLOCK_SIZE_N < _ne1 ? first + BLOCK_SIZE_N : _ne1; + int64_t this_ne1 = last - first; + threadgroup ushort2 * this_rowids = rowids + istep*BLOCK_SIZE_N; - kernel_mul_mm_id_impl( - src0, - src1, - rowids, - dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, - tgpig, - tiitg, - sgitg); + kernel_mul_mm_id_impl( + src0, + src1, + this_rowids, + dst, + ne00, + ne02, + nb01, + nb02, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + this_ne1, + ne0*ne1, + shared_memory, + tgpig, + tiitg, + sgitg); + } } #define QK_NL 16 From 2ee6263e24ee5f613c46696589ba5d4362017ba7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 3 Apr 2025 15:50:53 +0200 Subject: [PATCH 070/132] Fix GCC compilation errors on ARM (#309) * Fix GCC compilation errors on ARM * One more --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1c8a991d..9a34270b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12735,7 +12735,7 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI int nb = n / 32; GGML_ASSERT(nb%4 == 0); uint8x16_t qx[8]; - int32x4_t acc[nrc_y] = {}; + float32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); auto mask = vdupq_n_s8(0x03); float d8[4*nrc_y]; @@ -14140,7 +14140,7 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i auto scale2 = vmulq_f32(scale2_x, scale_y); info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0]))); info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1]))); - acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0.f); } } } @@ -14823,11 +14823,11 @@ inline float32x4_t v_tanh(float32x4_t x) { return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask))); //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one)); } -inline float32x4_t v_tanh(float16x8_t x) { - auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); - auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); - return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); -} +//inline float32x4_t v_tanh(float16x8_t x) { +// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x))); +// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x))); +// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2)); +//} inline float32x4_t v_silu(float32x4_t x) { const float32x4_t one = vdupq_n_f32(1.0f); const float32x4_t zero = vdupq_n_f32(0.0f); @@ -15671,7 +15671,9 @@ struct HelperQ60 final : public BaseHelper { auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0; #ifdef __aarch64__ // TODO - auto vd = F16::set1(*(const float16_t *)&dl->d); + const float16_t * d16 = (const float16_t *)&dl->d; + auto vd = F16::set1(d16[0]); + //auto vd = F16::set1(*(const float16_t *)&dl->d); auto qh8 = vld1_u8(dl->qh); auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); auto qs = vld1q_u8(dl->qs); @@ -15819,7 +15821,7 @@ struct FlashMS { return vmaxvq_f32(vmax); } inline float load_apply_mask_and_scale(int j, float32x4_t * vk, const char * mask) { - auto vzero = vdupq_n_f32(0); + auto vzero = vdupq_n_f16(0); auto vinf = vdupq_n_f32(-INFINITY); for (int l = 0; l < k_step/8; ++l) { auto vm = vceqq_f16(vzero, vld1q_f16((const float16_t *)mask + 8*l)); @@ -15827,9 +15829,9 @@ struct FlashMS { auto vm2 = vzip2q_u16(vm, vm); auto kq = vld1q_f32_x2(cache + k_step*j + 8*l); vk[2*l+0] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[0]), vm1), - vbicq_u32(vinf, vm1))); + vbicq_u32(vreinterpretq_u32_f32(vinf), vm1))); vk[2*l+1] = vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(kq.val[1]), vm2), - vbicq_u32(vinf, vm2))); + vbicq_u32(vreinterpretq_u32_f32(vinf), vm2))); } float32x4_t vmax = vdupq_n_f32(-INFINITY); auto vscale32 = vcvt_f32_f16(vget_low_f16(vscale)); From 073eda985ea16098cff5b708c87ce70d41784ea2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 3 Apr 2025 17:54:25 +0200 Subject: [PATCH 071/132] Metal: FA and FlashMLA (#310) * Metal: WIP to update Metal FA implementation Dk=192, Dv=128 works, but not Dk = 576, Dv = 512 * Metal FA: go to float * WIP * Metal FA: MLA options now all work --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-metal.m | 558 ++++++++++++---- ggml/src/ggml-metal.metal | 1261 ++++++++++++++++++++++++------------- 2 files changed, 1248 insertions(+), 571 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index aa50d448..501fe5a2 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -225,6 +225,39 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, @@ -276,9 +309,33 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F16_F16, @@ -290,7 +347,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CONCAT, + GGML_METAL_KERNEL_TYPE_CONCAT_F32, + GGML_METAL_KERNEL_TYPE_CONCAT_F16, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SUM_ROWS, @@ -793,6 +851,39 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32, mul_mm_iq4_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32, mul_mm_iq5_k_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32, mul_mm_iq6_k_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16, mul_mm_f32_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16, mul_mm_f16_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16, mul_mm_bf16_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16, mul_mm_q4_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16, mul_mm_q4_1_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16, mul_mm_q5_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16, mul_mm_q5_1_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16, mul_mm_q6_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16, mul_mm_q8_0_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16, mul_mm_q2_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16, mul_mm_q3_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16, mul_mm_q4_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16, mul_mm_q5_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16, mul_mm_q6_K_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16, mul_mm_iq2_xxs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16, mul_mm_iq2_xs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16, mul_mm_iq3_xxs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16, mul_mm_iq3_s_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16, mul_mm_iq2_s_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16, mul_mm_iq1_s_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16, mul_mm_iq1_m_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16, mul_mm_iq1_bn_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16, mul_mm_iq2_bn_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16, mul_mm_iq4_nl_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16, mul_mm_iq4_xs_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16, mul_mm_iq4_ks_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16, mul_mm_iq4_kss_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16, mul_mm_iq2_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16, mul_mm_iq2_ks_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16, mul_mm_iq3_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16, mul_mm_iq4_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16, mul_mm_iq5_k_f16, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16, mul_mm_iq6_k_f16, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, ctx->support_simdgroup_mm); @@ -844,9 +935,33 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,flash_attn_ext_f16_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,flash_attn_ext_f16_hk576_hv512, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,flash_attn_ext_q8_0_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,flash_attn_ext_q8_0_hk576_hv512, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80, flash_attn_ext_vec_f16_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112, flash_attn_ext_vec_f16_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,flash_attn_ext_vec_f16_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,flash_attn_ext_vec_f16_hk576_hv512, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80, flash_attn_ext_vec_q8_0_h80, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112, flash_attn_ext_vec_q8_0_h112, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,flash_attn_ext_vec_q8_0_hk192_hv128, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,flash_attn_ext_vec_q8_0_hk576_hv512, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); @@ -858,7 +973,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q6_0, cpy_f32_q6_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F32, concat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT_F16, concat_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } @@ -1001,17 +1117,24 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: - if (op->src[1]->type != GGML_TYPE_F16) { + if (!ctx->support_simdgroup_mm) { + return false; // TODO: over-restricted for vec-kernels + } + if (op->src[1]->type != op->src[2]->type || + (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_Q8_0)) { return false; } - if (op->src[2]->type != GGML_TYPE_F16) { - return false; + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + return (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) || + (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512); } - if (op->src[0]->ne[0] == 256) { - return false; - } - return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels + return (op->src[1]->ne[0] == 64 || op->src[1]->ne[0] == 80 || + op->src[1]->ne[0] == 96 || op->src[1]->ne[0] == 112 || + op->src[1]->ne[0] == 128 || op->src[1]->ne[0] == 256); case GGML_OP_MUL_MAT: + return ctx->support_simdgroup_reduction && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + !(op->src[0]->type >= GGML_TYPE_Q4_0_R8 && op->src[0]->type <= GGML_TYPE_Q8_K_R8); case GGML_OP_MUL_MAT_ID: return ctx->support_simdgroup_reduction && (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); @@ -1157,7 +1280,18 @@ static void ggml_metal_encode_node( switch (dst->op) { case GGML_OP_CONCAT: { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + GGML_ASSERT(src0->type == src1->type && src0->type == dst->type); + + id pipeline; + if (dst->type == GGML_TYPE_F32) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F32].pipeline; + } + else if (dst->type == GGML_TYPE_F16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT_F16].pipeline; + } + else { + GGML_ABORT("CONCAT not implemented for this type"); + } const int32_t dim = ((int32_t *) dst->op_params)[0]; @@ -1945,7 +2079,7 @@ static void ggml_metal_encode_node( if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && + (src1t == GGML_TYPE_F32 || src1t == GGML_TYPE_F16) && ne00 % 32 == 0 && ne00 >= 64 && (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); @@ -1960,41 +2094,84 @@ static void ggml_metal_encode_node( id pipeline = nil; - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break; - case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break; - case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; - case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break; - case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break; - default: GGML_ABORT("MUL MAT-MAT not implemented"); + if (src1->type == GGML_TYPE_F32) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F32].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F32 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F32 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F32 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F32 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + } + else if (src1->type == GGML_TYPE_F16) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q6_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ1_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_BN_F16 ].pipeline; break; + case GGML_TYPE_IQ2_BN: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_BN_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ4_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KS_F16 ].pipeline; break; + case GGML_TYPE_IQ4_KSS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_KSS_F16].pipeline; break; + case GGML_TYPE_IQ2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_KS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_KS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_K_F16 ].pipeline; break; + case GGML_TYPE_IQ4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_K_F16 ].pipeline; break; + case GGML_TYPE_IQ5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ5_K_F16 ].pipeline; break; + case GGML_TYPE_IQ6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ6_K_F16 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + } + else { + GGML_ABORT("Unsupported src1 type for MUL-MAT"); } [encoder setComputePipelineState:pipeline]; @@ -3204,8 +3381,9 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne11 % 32 == 0); GGML_ASSERT(src0->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_are_same_shape (src1, src2)); + GGML_ASSERT(src1->type == src2->type); + GGML_ASSERT(ne11 == ne21); + GGML_ASSERT(ne12 == ne22); struct ggml_tensor * src3 = node->src[3]; @@ -3250,70 +3428,189 @@ static void ggml_metal_encode_node( bool use_vec_kernel = false; - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192 && ne00 != 576)) { + switch (src1->type) { + case GGML_TYPE_F16: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q8_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + { + GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type)); + GGML_METAL_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } } } else { use_vec_kernel = true; - - switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + switch (src1->type) { + case GGML_TYPE_F16: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; + case GGML_TYPE_Q8_0: + { + if (ne00 == 192 && ne20 == 128) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; + } + else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; + } else { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %d\n", (int)ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + } break; default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + { + GGML_METAL_LOG_ERROR("unsupported type: %s\n", ggml_type_name(src1->type)); + GGML_METAL_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } } + } + typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; + } ggml_metal_kargs_flash_attn_ext; + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ softcap, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&softcap length:sizeof(softcap) atIndex:27]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:28]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel @@ -3324,10 +3621,19 @@ static void ggml_metal_encode_node( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + // 2*(2*ncpsg + nqptg)*(nsg) + // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) + int64_t nsgmax = 2; while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsgmax); if (smem > ctx->device.maxThreadgroupMemoryLength) { break; } @@ -3338,14 +3644,14 @@ static void ggml_metal_encode_node( // simdgroups per threadgroup (a.k.a. warps) const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsg); - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { // half1x4 kernel const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! @@ -3355,8 +3661,27 @@ static void ggml_metal_encode_node( GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne00*(nsg) + // each simdgroup has a full f16 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); int64_t nsg = 1; while (nsg <= nsgt) { @@ -3364,13 +3689,14 @@ static void ggml_metal_encode_node( } nsg /= 2; - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + const size_t smem = FATTN_SMEM(nsg); - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } break; case GGML_OP_DUP: diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a67ec336..d3a2858c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2576,262 +2576,358 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } -typedef void (flash_attn_ext_f16_t)( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant float & softcap, - constant uint32_t & n_head_log2, - threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]); +//========================================================================================== +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src)); +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + if constexpr (is_same_v) { + const half d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = (half)qs[i + 16*il] * d; + } + } else { + const float d = xb->d; + for (int i = 0; i < 16; i++) { + reg[i/4][i%4] = qs[i + 16*il] * d; + } + } +} + +template +void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; // ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant float & softcap, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short KV = 8, // key/value processed per each simdgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]*Q; + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]*Q; - const short D4 = D/4; - const short D8 = D/8; - //const short Q8 = Q/8; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) + constexpr short DK4 = DK/4; + constexpr short DK8 = DK/8; + constexpr short DK16 = DK/16; + constexpr short DV4 = DV/4; + constexpr short DV8 = DV/8; + constexpr short DV16 = DV/16; - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short TF = T/2; // shared memory size per query in (float) - const short T4 = T/4; // shared memory size per query in (half4) + constexpr short NW = N_SIMDWIDTH; + constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) default: 72 - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + const short TS = nsg*SH; // shared memory size per query in (s_t == float) = 288 for nsg = 4 + const short T = DK + 2*TS; // shared memory size per query in (half) = 1152 for nsg = 4 and DK = 576 + // => Q*T is 9216 for nsg = 4 and DK = 576 => 18432 bytes => overflows the 16384 bytes predicted as shmem in ggml-metal.m + + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix + + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; + o8x8_t lo[DV8]; // For DV = 512 we have DV8 = 64 => 4096 entries per thread. Do we even have so much // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { - sq4[j*T4 + i] = (half4) q4[i]; + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 + j < args.ne01) { + sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*T4 + i] = 0.0h; + sq4[j*DK4 + i] = (q4_t) 0.0f; } } } // zero out lo - for (short i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); + for (short i = 0; i < DV8; ++i) { + lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); } // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TF + i] = 0.0f; + ss[j*TS + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); { - float S[Q] = { [0 ... Q-1] = 0.0h }; - float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + float S[Q] = { [0 ... Q-1] = 0.0f }; + float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; + // thread indices inside the simdgroup + // TODO: see if we can utilize quad-group functions for better performance + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) + const short tx = tiisg%4; + const short ty = tiisg/4; - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // k indices - const short ik2 = iq2/rk2; - const short ik3 = iq3/rk3; - - // v indices - const short iv2 = iq2/rv2; - const short iv3 = iq3/rv3; - - // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); - } - - // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); + const bool has_mask = mask != q; float slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; + if (args.max_bias > 0.0f) { + const short h = iq2; - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } + if (has_mask) { + // used to detect blocks full of -INF + float smax = -INFINITY; + + // load the mask in shared memory + #pragma unroll(Q) + for (short j = 0; j < Q; ++j) { + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + + const float m = pm[ic + tiisg]; + + ss[j*TS + C + tiisg] = m; + smax = max(smax, m); + } + + smax = simd_max(smax); + + if (smax == -INFINITY) { + continue; + } + } + // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + #pragma unroll(DK8) + for (short i = 0; i < DK8; ++i) { + k8x8_t mk; + simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + q8x8_t mq; + simdgroup_load(mq, sq + i*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } + } } - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + // cast qk_t -> s_t + //s8x8_t mqks(1.0f); + //simdgroup_multiply(mqks, mqk, mqks); + //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - const short tx = tiisg%4; - const short ty = tiisg/4; - - // mqk = mqk*scale - ss[8*cc + ty*TF + 2*tx + 0] *= scale; - ss[8*cc + ty*TF + 2*tx + 1] *= scale; - - if (softcap != 0.0f) { - ss[8*cc + ty*TF + 2*tx + 0] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]); - ss[8*cc + ty*TF + 2*tx + 1] = softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]); - } - - if (mask != q) { - // mqk = mqk*scale + mask*slope - ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; - ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; - } + simdgroup_store(mqk, ss + 8*cc, TS, 0, false); } } - // used to detect blocks full of -INF - float smax = -INFINITY; - // online softmax { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const short p = tiisg; - + for (ushort j = 0; j < Q; ++j) { const float m = M[j]; - const float s = ss[j*TF + p]; - smax = simd_max(max(smax, s)); + // scale and apply the logitcap / mask + float s = ss[j*TS + tiisg]*args.scale; + + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); + } + + // mqk = mqk + mask*slope + s += slope*ss[j*TS + C + tiisg]; + M[j] = simd_max(max(M[j], s)); - ms[j] = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); - S[j] = S[j]*ms[j] + simd_sum(vs); + S[j] = S[j]*ms + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*TF + p] = vs; - } + ss[j*TS + tiisg] = vs; - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } } } - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - // O = diag(ms)*O { - simdgroup_float8x8 mm; - simdgroup_load(mm, ss + C, TF, 0, false); + s8x8_t mm; + simdgroup_load(mm, ss + 2*C, TS, 0, false); - for (short i = 0; i < D8; ++i) { + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { simdgroup_multiply(lo[i], mm, lo[i]); } } @@ -2839,16 +2935,64 @@ kernel void kernel_flash_attn_ext_f16( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + s8x8_t ms; + simdgroup_load(ms, ss + 8*cc, TS, 0, false); - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + if (is_same::value) { + // we can read directly from global memory + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*cc, TF, 0, false); + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + v8x8_t mv; + simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 - simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); + } + } else { + for (short ii = 0; ii < DV16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + + if (DV16%4 == 0) { + // no need for bound checks + { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } else { + if (ii + tx < DV16) { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DV16; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } + } } } } @@ -2857,23 +3001,23 @@ kernel void kernel_flash_attn_ext_f16( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = 0; j < Q; ++j) { if (tiisg == 0) { - ss[j*TF + 0] = S[j]; - ss[j*TF + 1] = M[j]; + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; } } } // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; + for (ushort sg = 1; sg < nsg; ++sg) { + float S = { 0.0f }; + float M = { -__FLT16_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -2882,11 +3026,11 @@ kernel void kernel_flash_attn_ext_f16( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TF + 0]; - const float S1 = ss[j*TF + sg*SH + 0]; + const float S0 = ss[j*TS + 0]; + const float S1 = ss[j*TS + sg*SH + 0]; - const float M0 = ss[j*TF + 1]; - const float M1 = ss[j*TF + sg*SH + 1]; + const float M0 = ss[j*TS + 1]; + const float M1 = ss[j*TS + sg*SH + 1]; M = max(M0, M1); @@ -2896,25 +3040,27 @@ kernel void kernel_flash_attn_ext_f16( S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[j*TF + 0] = S; - ss[j*TF + 1] = M; + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; - ss[j*TF + C + j ] = ms0; - ss[j*TF + C + j + sg*SH] = ms1; + ss[j*TS + 2*C + j ] = ms0; + ss[j*TS + 2*C + j + sg*SH] = ms1; } } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 { - simdgroup_half8x8 t; - simdgroup_float8x8 ms0; - simdgroup_float8x8 ms1; + s8x8_t ms0; + s8x8_t ms1; - simdgroup_load(ms0, ss + C, TF, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + simdgroup_load(ms0, ss + 2*C, TS, 0, false); + simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); - for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + o8x8_t t; + + simdgroup_load (t, so + i*8, DV, 0, false); simdgroup_multiply(t, ms1, t); simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); @@ -2925,8 +3071,8 @@ kernel void kernel_flash_attn_ext_f16( // store result to shared memory (reuse sq) if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); + for (short i = 0; i < DV8; ++i) { + simdgroup_store(lo[i], so + i*8, DV, 0, false); } } @@ -2934,206 +3080,246 @@ kernel void kernel_flash_attn_ext_f16( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const float S = ss[j*TF + 0]; + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { + const float S = ss[j*TS + 0]; - for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; +// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as +// template to be able to explore different combinations +// +#define FA_TYPES \ + half, half4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_vec_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant float & softcap, - constant uint32_t & n_head_log2, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#undef FA_TYPES + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // key type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]; + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]; - const short D4 = D/4; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) + constexpr short DK4 = DK/4; + constexpr short DV4 = DV/4; + constexpr short NW = N_SIMDWIDTH; + constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads + constexpr short SH = 4*C; // shared memory per simdgroup - const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short T = DK + nsg*SH; // shared memory size per query in (half) - float slope = 1.0f; + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[D4/NW]; + // store the result for all queries in local memory (the O matrix from the paper) + o4_t lo[DV4/NL]; // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); - for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { - sq4[i] = (half4) q4[i]; + for (short i = tiisg; i < DK4; i += NW) { + if (iq1 < args.ne01) { + sq4[i] = (q4_t) q4[i]; } else { - sq4[i] = 0.0h; + sq4[i] = (q4_t) 0.0f; } } // zero out lo - for (short i = tiisg; i < D4; i += NW) { - lo[i/NW] = 0.0h; + for (short i = 0; i < DV4/NL; ++i) { + lo[i] = (o4_t) 0.0f; } // zero out shared memory SH for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = 0.0h; + ss4[i] = (s4_t) 0.0f; } threadgroup_barrier(mem_flags::mem_threadgroup); { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; + float S = 0.0f; + float M = -__FLT16_MAX__/2; - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; + // thread indices inside the simdgroup + const short tx = tiisg%NL; + const short ty = tiisg/NL; - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); - // k indices - const short ik2 = iq2 / rk2; - const short ik3 = iq3 / rk3; - - // v indices - const short iv2 = iq2 / rv2; - const short iv3 = iq3 / rv3; - - // load the queries from shared memory into local memory - float4 mq[D4]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - mq[i] = (float4)sq4[i]; - } + const bool has_mask = mask != q; // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31); + + float slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const float base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } + if (has_mask) { + sm[tiisg] = pm[ic + tiisg]; + } + // Q*K^T { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - float4 mqk = { 0.0h }; + // each simdgroup processes 1 query and NE (NW/NL) head elements + for (short cc = 0; cc < C/NE; ++cc) { + qk_t mqk = 0.0f; - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; + #pragma unroll(DK4/NL) + for (short ii = 0; ii < DK4; ii += NL) { + const short i = ii + tx; - float4x4 mk; - mk[0] = (float4)pk4[i + 0*(nb11/8)]; - mk[1] = (float4)pk4[i + 1*(nb11/8)]; - mk[2] = (float4)pk4[i + 2*(nb11/8)]; - mk[3] = (float4)pk4[i + 3*(nb11/8)]; + k4_t mk; + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += (float4) (mq[i] * mk); + // note: this is less precise than the version below + //mqka[0] += dot(mq[0], mk[0]); + //mqka[1] += dot(mq[1], mk[1]); + //mqka[2] += dot(mq[2], mk[2]); + //mqka[3] += dot(mq[3], mk[3]); + + //q4x4_t mq = sq4x4[i]; + //mqka[0] += dot((float4) mq[0], (float4) mk[0]); + //mqka[1] += dot((float4) mq[1], (float4) mk[1]); + //mqka[2] += dot((float4) mq[2], (float4) mk[2]); + //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + + mqk += dot((float4) mk, (float4) sq4[i]); } - // reduce the results from the threads in the simdgroup - mqk += simd_shuffle_down(mqk, 16); - mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); + static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails + + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk += simd_shuffle_down(mqk, 16); + } + if (NE <= 2) { + mqk += simd_shuffle_down(mqk, 8); + } + if (NE <= 4) { + mqk += simd_shuffle_down(mqk, 4); + } + if (NE <= 8) { + mqk += simd_shuffle_down(mqk, 2); + } + if (NE <= 16) { + mqk += simd_shuffle_down(mqk, 1); + } // mqk = mqk*scale + mask*slope - if (tiisg == 0) { - mqk *= scale; - if (softcap != 0.0f) { - mqk = softcap*precise::tanh(mqk); + if (tx == 0) { + mqk *= args.scale; + + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); } - mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; - ss4[cc] = mqk; + mqk += sm[NE*cc + ty]*slope; + + ss[NE*cc + ty] = mqk; } - } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // online softmax { - const short p = tiisg; - const float m = M; - const float s = ss[p]; + const float s = ss[tiisg]; M = simd_max(max(M, s)); @@ -3143,47 +3329,96 @@ kernel void kernel_flash_attn_ext_vec_f16( S = S*ms + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[p] = vs; + ss[tiisg] = vs; // O = diag(ms)*O -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - lo[i/NW] *= ms; + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + lo[ii/NL] *= ms; } } + simdgroup_barrier(mem_flags::mem_threadgroup); + // O = O + (Q*K^T)*V { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + //#pragma unroll(C/NE) + for (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; + const s4_t ms(ss[NE*cc + ty]); - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + #pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + const short i = ii + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii/NL] += o4_t(float4(mv)*float4(ms)); } } } - } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) if (tiisg == 0) { - ss[0] = S; - ss[1] = M; + ss[0] = (s_t) S; + ss[1] = (s_t) M; } } + // simdgroup reduce (NE = 4) + // [ 0, 8, 16, 24] -> [ 0] + // [ 1, 9, 17, 25] -> [ 1] + // [ 2, 10, 18, 26] -> [ 2] + // [ 3, 11, 19, 27] -> [ 3] + // [ 4, 12, 20, 28] -> [ 4] + // [ 5, 13, 21, 29] -> [ 5] + // [ 6, 14, 22, 30] -> [ 6] + // [ 7, 15, 23, 31] -> [ 7] + for (short ii = 0; ii < DV4; ii += NL) { + if (NE > 1) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + } + + if (NE > 2) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + } + + if (NE > 4) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + } + + if (NE > 8) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + } + + if (NE > 16) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // store results to shared memory - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = lo[ii/NW]; + for (short i = tiisg; i < DV4; i += NL) { + sr4[i] = lo[i/NL]; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3191,11 +3426,11 @@ kernel void kernel_flash_attn_ext_vec_f16( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*(SH/2) + 0]; - const float M0 = ss[ 1]; - const float M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*(SH/2) + 1]; const float M = max(M0, M1); @@ -3210,9 +3445,8 @@ kernel void kernel_flash_attn_ext_vec_f16( } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + for (short i = tiisg; i < DV4; i += NW) { + sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; } } @@ -3225,15 +3459,50 @@ kernel void kernel_flash_attn_ext_vec_f16( if (sgitg == 0) { const float S = ss[0]; - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + for (short i = tiisg; i < DV4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S; } } } -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem +// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max +// +#define FA_TYPES \ + half4, \ + half4, \ + half4, \ + float, \ + float, float4, \ + half4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h80")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h80")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h112")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h112")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +#undef FA_TYPES template kernel void kernel_cpy( @@ -3809,7 +4078,8 @@ kernel void kernel_cpy_f32_iq4_nl( } } -kernel void kernel_concat( +template +static inline void concat_impl( device const char * src0, device const char * src1, device char * dst, @@ -3849,21 +4119,93 @@ kernel void kernel_concat( int64_t o[4] = {0, 0, 0, 0}; o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - device const float * x; + device const src_t * x; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (device const src_t *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + x = (device const src_t *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device src_t * y = (device src_t *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } +kernel void kernel_concat_f32( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & dim, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + concat_impl(src0, src1, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, dim, tgpig, tpitg, ntg); +} + +kernel void kernel_concat_f16( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & dim, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + concat_impl(src0, src1, dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, dim, tgpig, tpitg, ntg); +} + + + void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, @@ -7092,21 +7434,6 @@ kernel void kernel_mul_mv_iq6_k_f32( //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} template void dequantize_bf16(device const half4x4 * src, short il, thread type4x4 & reg) { @@ -7220,22 +7547,6 @@ void dequantize_q6_0(device const block_q6_0 *xb, short il, thread type4x4 & reg } } -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - if constexpr (is_same_v) { - const half d = xb->d; - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (half)qs[i + 16*il] * d; - } - } else { - const float d = xb->d; - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = qs[i + 16*il] * d; - } - } -} - template void dequantize_q2_K(device const block_q2_K * xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -7967,7 +8278,7 @@ struct DequantizerRSBN { }; // each block_q contains 16*nl weights -template +template kernel void kernel_mul_mm(device const uchar * src0, device const uchar * src1, device float * dst, @@ -8017,8 +8328,8 @@ kernel void kernel_mul_mm(device const uchar * src0, uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0); - device const float * y = (device const float *)(src1 + device const char * cx = (device const char *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0); + device const src1_t * y = (device const src1_t *)(src1 + nb12 * im + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); @@ -8038,7 +8349,12 @@ kernel void kernel_mul_mm(device const uchar * src0, + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + if (is_same_v) { + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + } else { + half2x4 h = *((device half2x4 *)y); + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = (float2x4)h; + } deq.next(); y += BLOCK_SIZE_K; @@ -8381,41 +8697,76 @@ template [[host_name("kernel_get_rows_iq2_ks")]] kernel get_rows_q_t kernel_get template using DD = DefaultDequantizer; -typedef decltype(kernel_mul_mm>) mat_mm_t; +typedef decltype(kernel_mul_mm, float>) mat_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm, float>; +template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm, float>; + +template [[host_name("kernel_mul_mm_f32_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_f16_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q6_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq3_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq5_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq6_k_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq1_bn_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_bn_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_ks_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq4_kss_f16")]] kernel mat_mm_t kernel_mul_mm, half>; +template [[host_name("kernel_mul_mm_iq2_ks_f16")]] kernel mat_mm_t kernel_mul_mm, half>; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q6_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq3_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq5_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq6_k_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq1_bn_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_bn_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_ks_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq4_kss_f32")]] kernel mat_mm_t kernel_mul_mm>; -template [[host_name("kernel_mul_mm_iq2_ks_f32")]] kernel mat_mm_t kernel_mul_mm>; // // indirect matrix-matrix multiplication From c616306a011cb93d6142285f85cb6803a1a02564 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 4 Apr 2025 11:04:19 +0200 Subject: [PATCH 072/132] Add -flax-vector-conversions for GCC on ARM (#311) Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index f7f15fbd..0aa08adf 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -1147,6 +1147,10 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR if (GGML_SVE) list(APPEND ARCH_FLAGS -march=armv8.6-a+sve) endif() + if (CMAKE_CXX_COMPILER_ID MATCHES "GNU") + # else we fail on Gravitons and such + list(APPEND ARCH_FLAGS -flax-vector-conversions) + endif() endif() elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND From ec84855c6ae5a08686f3e5d8010e38064269deb3 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 5 Apr 2025 14:31:27 +0200 Subject: [PATCH 073/132] We need to synchronize before using device to host async memcpy (#313) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 5f4731c8..0096a00b 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2208,7 +2208,7 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - //CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); std::vector rmapping(ids->ne[1]*n_ids); moe_counts.resize(n_as, 0); @@ -2239,6 +2239,7 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i]; CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); } @@ -2302,11 +2303,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; - //std::vector ids_host(ggml_nbytes(ids)); - //const char * ids_dev = (const char *) ids->data; - //CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - //CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; ggml_tensor dst_row = *dst; @@ -2335,12 +2331,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { for (int64_t id = 0; id < n_ids; id++) { const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); if (i02 < 0 || i02 >= n_as) continue; - //GGML_ASSERT(i02 >= 0 && i02 < n_as); const int64_t i11 = id % ne11; const int64_t i12 = iid1; From abbabf7ca14a5bfe2f141362876ccd60b8968d48 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 7 Apr 2025 10:41:40 +0200 Subject: [PATCH 074/132] Update LICENSE I did not realize until today that the [ggml authors](https://github.com/ggml-org/ggml/blob/master/AUTHORS) is not the same thing as the [llama.cpp authors](https://github.com/ggml-org/llama.cpp/blob/master/AUTHORS). This PR corrects my mistake. --- LICENSE | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 03f0ee9c..34fbbe73 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,8 @@ MIT License Copyright (c) 2023-2024 The ggml authors -Copyright (c) 2024 Iwan Kawrakow +Copyright (c) 2023-2024 The llama.cpp authors +Copyright (c) 2024-2025 Iwan Kawrakow Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From a051f08b8f059fa10dd089d231b975291c122e9d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 7 Apr 2025 10:43:26 +0200 Subject: [PATCH 075/132] Add copyright notices (#317) Co-authored-by: Iwan Kawrakow --- common/common.cpp | 7 +++++++ common/common.h | 7 +++++++ examples/llama-bench/llama-bench.cpp | 7 +++++++ examples/perplexity/perplexity.cpp | 7 +++++++ examples/quantize-stats/quantize-stats.cpp | 7 +++++++ examples/quantize/quantize.cpp | 7 +++++++ ggml/include/ggml.h | 7 +++++++ ggml/src/ggml-common.h | 2 +- ggml/src/ggml-cuda.cu | 6 ++++++ ggml/src/ggml-cuda/argsort.cu | 6 ++++++ ggml/src/ggml-cuda/argsort.cuh | 6 ++++++ ggml/src/ggml-cuda/binbcast.cu | 7 +++++++ ggml/src/ggml-cuda/concat.cu | 7 +++++++ ggml/src/ggml-cuda/convert.cuh | 7 +++++++ ggml/src/ggml-cuda/cpy.cu | 7 +++++++ ggml/src/ggml-cuda/fattn-common.cuh | 7 +++++++ ggml/src/ggml-cuda/fattn-tile-f16.cu | 7 +++++++ ggml/src/ggml-cuda/fattn-tile-f32.cu | 7 +++++++ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 7 +++++++ ggml/src/ggml-cuda/fattn-vec-f32.cuh | 7 +++++++ ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 7 +++++++ ggml/src/ggml-cuda/fattn.cu | 7 +++++++ ggml/src/ggml-cuda/getrows.cu | 7 +++++++ ggml/src/ggml-cuda/iqk_mmvq.cuh | 6 ++++++ ggml/src/ggml-cuda/mmq.cu | 7 +++++++ ggml/src/ggml-cuda/mmq.cuh | 7 +++++++ ggml/src/ggml-cuda/mmvq.cu | 7 +++++++ ggml/src/ggml-cuda/mmvq.cuh | 7 +++++++ ggml/src/ggml-cuda/quantize.cu | 7 +++++++ ggml/src/ggml-cuda/quantize.cuh | 7 +++++++ ggml/src/ggml-cuda/rope.cu | 7 +++++++ ggml/src/ggml-cuda/softcap.cu | 6 ++++++ ggml/src/ggml-cuda/softcap.cuh | 7 +++++++ ggml/src/ggml-cuda/softmax.cu | 7 +++++++ ggml/src/ggml-cuda/softmax.cuh | 7 +++++++ ggml/src/ggml-cuda/unary.cu | 7 +++++++ ggml/src/ggml-cuda/unary.cuh | 7 +++++++ ggml/src/ggml-quants.c | 2 +- ggml/src/iqk/iqk_config.h | 6 ++++++ ggml/src/iqk/iqk_flash_attn.cpp | 6 ++++++ ggml/src/iqk/iqk_flash_impl.h | 6 ++++++ ggml/src/iqk/iqk_quantize.h | 6 ++++++ include/llama.h | 7 +++++++ src/llama-impl.h | 7 +++++++ src/llama.cpp | 7 +++++++ 45 files changed, 294 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 5b3d60b3..f0c618e0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #if defined(_MSC_VER) #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif diff --git a/common/common.h b/common/common.h index 79c5fe96..b4f75236 100644 --- a/common/common.h +++ b/common/common.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + // Various helper functions and utilities #pragma once diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index ae79ad02..74f51494 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include #include #include diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 95aedce6..07ac88af 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.h" #include "llama.h" diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index ff4e9bd4..79905f54 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #define LLAMA_API_INTERNAL #include "common.h" #include "ggml.h" diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1d00c874..0a91b537 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.h" #include "llama.h" diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 7cc9100d..beeb3c09 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The ggml authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once // diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 59702e32..5a6417fc 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -1,6 +1,6 @@ // -// Copyright (C) 2023-2024 The ggml authors // Copyright (C) 2024 Iwan Kawrakow +// Copyright (C) 2023-2024 The ggml authors // MIT license // SPDX-License-Identifier: MIT // diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 0096a00b..63e16d52 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1,3 +1,9 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #include "ggml-cuda.h" #include "ggml.h" #include "ggml-backend-impl.h" diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 1734b771..df214082 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -1,3 +1,9 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #include "argsort.cuh" template diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 4bafa2d7..7bbfdf4d 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,9 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index f217488e..701b0f80 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "binbcast.cuh" static __device__ __forceinline__ float op_repeat(const float a, const float b) { diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index f89ce974..ee98bf18 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "concat.cuh" // contiguous kernels diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 0efcecde..5d107450 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 3f2d0089..3b87cbad 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "cpy.cuh" #include "convert.cuh" diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index a46f03e5..08bb5cd3 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include "common.cuh" diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index bf2a4521..420f0bb0 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile-f16.cuh" diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 28846561..f525f1bb 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile-f32.cuh" diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index b6ba07e4..2cf4f4ef 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 404afe2e..c91cef3d 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index c5ffc7d1..d39c6a6e 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 6329908d..8c3f4000 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile-f16.cuh" diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 973b6526..f734271a 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "getrows.cuh" #include "dequantize.cuh" diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cuh b/ggml/src/ggml-cuda/iqk_mmvq.cuh index 15df20f5..1f55ddb9 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cuh +++ b/ggml/src/ggml-cuda/iqk_mmvq.cuh @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" void mul_mat_vec_iq2_k_q8_1_cuda( diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index f9fc2438..3b959182 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "mmq.cuh" void ggml_cuda_op_mul_mat_q( diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 416b4336..88e023a1 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include "common.cuh" diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index b9e9c216..69ccef1d 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "mmvq.cuh" #include "iqk_mmvq.cuh" #include "vecdotq.cuh" diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index c0699b04..525c6bc0 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 90dfac92..953eb9d9 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "quantize.cuh" #include diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 69622e18..0be5bf0e 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include "common.cuh" diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 42ca72af..dd25a2ce 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "rope.cuh" struct rope_corr_dims { diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 499025d1..268b9ae5 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, float s_before, float s_after, const int k) { diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh index 2b875bfb..4c345b2e 100644 --- a/ggml/src/ggml-cuda/softcap.cuh +++ b/ggml/src/ggml-cuda/softcap.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_SOFTCAP_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 6f3056e6..c006301f 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "softmax.cuh" diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh index 49a83dfa..b97e9a92 100644 --- a/ggml/src/ggml-cuda/softmax.cuh +++ b/ggml/src/ggml-cuda/softmax.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index c422abbc..6312f25c 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "unary.cuh" static __global__ void gelu_f32(const float * x, float * dst, const int k) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index e55c4262..9bcd30a8 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #define CUDA_GELU_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e43a9b22..e5a9ec8c 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1,6 +1,6 @@ // -// Copyright (C) 2023-2024 The ggml authors // Copyright (C) 2024 Iwan Kawrakow +// Copyright (C) 2023-2024 The ggml authors // MIT license // SPDX-License-Identifier: MIT // diff --git a/ggml/src/iqk/iqk_config.h b/ggml/src/iqk/iqk_config.h index fa4972b8..29645f4e 100644 --- a/ggml/src/iqk/iqk_config.h +++ b/ggml/src/iqk/iqk_config.h @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #if defined IQK_IMPLEMENT diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 2c6da2b9..e302c944 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "iqk_config.h" #include "iqk_mul_mat.h" #include "iqk_flash_impl.h" diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h index bc68e0d8..68802927 100644 --- a/ggml/src/iqk/iqk_flash_impl.h +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once bool iqk_flash_attn_impl(int type_k, // type of k diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 478bd0de..24db374b 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -1,3 +1,9 @@ +// +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #include diff --git a/include/llama.h b/include/llama.h index 38822dbe..3275857a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #ifndef LLAMA_H #define LLAMA_H diff --git a/src/llama-impl.h b/src/llama-impl.h index 95277409..a9cbe0df 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #pragma once #define LLAMA_API_INTERNAL diff --git a/src/llama.cpp b/src/llama.cpp index 2ed50b20..d41bd6d8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,3 +1,10 @@ +// +// Copyright (C) 2023-2025 The llama.cpp authors +// Copyright (C) 2024-2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "llama-impl.h" #include "llama-vocab.h" #include "llama-grammar.h" From 2309ecda80db24cec11266700d4f9e88977ca750 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 7 Apr 2025 12:39:04 +0200 Subject: [PATCH 076/132] Better iq2_xs quantization (#312) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-quants.c | 127 ++++++++++++++++++++++++++++++----------- 1 file changed, 93 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e5a9ec8c..d8f1ddd1 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13085,6 +13085,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict } } +static int iq1_sort_helper(const void * left, const void * right) { + const float * l = left; + const float * r = right; + return *l < *r ? -1 : *l > *r ? 1 : 0; +} + static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); @@ -13114,6 +13120,9 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v bool is_on_grid_aux[2]; uint8_t block_signs[2]; uint16_t q2[2*(QK_K/16)]; + uint16_t index[2], aux_index[2]; + float sumx[17], sumw[17], pairs[32]; + int * int_pairs = (int *)(pairs + 1); for (int ibl = 0; ibl < nbl; ++ibl) { @@ -13166,11 +13175,35 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v memset(L, 0, 16); continue; } - float best = 0; - float scale = max/(2*kMaxQ-1); + for (int j = 0; j < 16; ++j) { + pairs[2*j] = xval[j]; + int_pairs[2*j] = j; + } + qsort(pairs, 16, 2*sizeof(float), iq1_sort_helper); + { + sumx[0] = sumw[0] = 0; + for (int j = 0; j < 16; ++j) { + int i = int_pairs[2*j]; + sumx[j+1] = sumx[j] + weight[i]*xval[i]; + sumw[j+1] = sumw[j] + weight[i]; + } + } + float best = 0, scale = 0; + for (int i1 = 0; i1 <= 16; ++i1) { + for (int i2 = i1; i2 <= 16; ++i2) { + float sumqx = (sumx[i1] - sumx[0])*1 + (sumx[i2] - sumx[i1])*3 + (sumx[16] - sumx[i2])*5; + float sumq2 = (sumw[i1] - sumw[0])*1 + (sumw[i2] - sumw[i1])*9 + (sumw[16] - sumw[i2])*25; + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + } + } + } + best = 0; + float eff_max = scale*(2*kMaxQ - 1); is_on_grid[0] = is_on_grid[1] = true; - for (int is = -9; is <= 9; ++is) { - float id = (2*kMaxQ-1+is*0.1f)/max; + index[0] = index[1] = 0; + for (int is = -7; is <= 7; ++is) { + float id = (2*kMaxQ-1+is*0.1f)/eff_max; float this_scale = 1/id; for (int k = 0; k < 2; ++k) { for (int i = 0; i < 8; ++i) { @@ -13186,6 +13219,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k); } + aux_index[k] = grid_index; } float sumqx = 0, sumq2 = 0; for (int i = 0; i < 16; ++i) { @@ -13198,35 +13232,45 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v scale = sumqx/sumq2; best = scale*sumqx; for (int i = 0; i < 16; ++i) L[i] = Laux[i]; for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k]; + for (int k = 0; k < 2; ++k) index[k] = aux_index[k]; } } - int n_not_ongrid = 0; - for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid; - if (n_not_ongrid > 0 && scale > 0) { - float id = 1/scale; - for (int k = 0; k < 2; ++k) { - if (is_on_grid[k]) continue; - uint16_t u = 0; - for (int i = 0; i < 8; ++i) { - int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); - l = MAX(0, MIN(kMaxQ-1, l)); - u |= (l << 2*i); - L[8*k + i] = l; + if (scale) { + for (int iter = 0; iter < 3; ++iter) { + float id = 1/scale; + bool changed = false; + for (int k = 0; k < 2; ++k) { + uint16_t u = 0; + for (int i = 0; i < 8; ++i) { + int l = nearest_int(0.5f*(id*xval[8*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 2*i); + Laux[8*k + i] = l; + } + int grid_index = kmap_q2xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; + grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, Laux + 8*k); + } + aux_index[k] = grid_index; + if (grid_index != index[k]) changed = true; } - int grid_index = kmap_q2xs[u]; - if (grid_index < 0) { - const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1; - grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k); + if (!changed) break; + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < 16; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; + best = scale*sumqx; + memcpy(L, Laux, 16); + for (int k = 0; k < 2; ++k) index[k] = aux_index[k]; + } + else break; } - float sumqx = 0, sumq2 = 0; - for (int i = 0; i < 16; ++i) { - float w = weight[i]; - float q = 2*L[i] + 1; - sumqx += w*xval[i]*q; - sumq2 += w*q*q; - } - if (sumq2 > 0) scale = sumqx/sumq2; } if (scale < 0) { scale = -scale; @@ -13257,13 +13301,34 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v float d = max_scale/31; y[ibl].d = GGML_FP32_TO_FP16(d); float id = 1/d; + float sumqx = 0, sumq2 = 0; for (int ib = 0; ib < QK_K/16; ++ib) { int l = nearest_int(0.5f*(id*scales[ib]-1)); l = MAX(0, MIN(15, l)); if (ib%2 == 0) y[ibl].scales[ib/2] = l; else y[ibl].scales[ib/2] |= (l << 4); + l = 2*l + 1; + const float * xb = xbl + 16*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + 16*ib; + for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i]; + } + for (int k = 0; k < 2; ++k) { + int grid_index = q2[2*ib+k] & 511; + const int8_t * grid = (const int8_t *)(iq2xs_grid + grid_index); + const uint8_t signs = ksigns_iq2xs[q2[2*ib+k] >> 9]; + for (int j = 0; j < 8; ++j) { + float w = weight[8*k+j]; + float q = 0.125f*l*grid[j]*(signs & kmask_iq2xs[j] ? -1.f : 1.f); + sumqx += w*q*xb[8*k+j]; + sumq2 += w*q*q; + } + } } memcpy(y[ibl].qs, q2, QK_K/4); + if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(1.05f*sumqx/sumq2); } } @@ -14103,12 +14168,6 @@ static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const return grid_index; } -static int iq1_sort_helper(const void * left, const void * right) { - const float * l = left; - const float * r = right; - return *l < *r ? -1 : *l > *r ? 1 : 0; -} - void iq1s_process_1block(int block_size, const float * xb, const float * weight, int8_t * L, float * the_scale, uint16_t * the_index, int * the_shift, float * pairs, float * sumx, float * sumw) { float max = fabsf(xb[0]); From b38759127a996261f9575333451e90a7b18a2252 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 7 Apr 2025 17:25:06 +0200 Subject: [PATCH 077/132] Use links for ggml/llama.cpp authors (#318) * Use links for ggml/llama.cpp authors * This file is not html * More --------- Co-authored-by: Iwan Kawrakow --- AUTHORS | 783 +------------------------------------------------------- LICENSE | 6 +- 2 files changed, 5 insertions(+), 784 deletions(-) diff --git a/AUTHORS b/AUTHORS index 1bd36158..227db9d3 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,782 +1,3 @@ -# date: Wed Jun 26 19:36:34 EEST 2024 -# this file is auto-generated by scripts/gen-authors.sh +Kawrakow +saood06 -0cc4m -0xspringtime <110655352+0xspringtime@users.noreply.github.com> -20kdc -2f38b454 -3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> -44670 <44670@users.noreply.github.com> -AN Long -AT -Aarni Koskela -Aaron Miller -Aaryaman Vasishta -Abheek Gulati -Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> -Abhishek Gopinath K <31348521+overtunned@users.noreply.github.com> -Adithya Balaji -AdithyanI -Adrian -Adrian Hesketh -Ahmet Zeer -AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> -Aisuko -Akarshan Biswas -Albert Jin -Alberto <57916483+albbus-stack@users.noreply.github.com> -Alex -Alex Azarov -Alex Azarov -Alex Klinkhamer -Alex Klinkhamer -Alex Nguyen -Alex Petenchea -Alex Renda -Alex von Gluck IV -Alexey Parfenov -Ali Chraghi <63465728+alichraghi@users.noreply.github.com> -Ali Nehzat -Ali Tariq -Alon -AlpinDale <52078762+AlpinDale@users.noreply.github.com> -Amir -AmirAli Mirian <37371367+amiralimi@users.noreply.github.com> -Ananta Bastola -Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> -András Salamon -Andrei -Andrew Canis -Andrew Downing -Andrew Duffy -Andrew Godfrey -Andy Tai -Arik Poznanski -Artem -Artem Zinnatullin -Artyom Lebedev -Asbjørn Olling -Ásgeir Bjarni Ingvarsson -Ashish <1856117+ashishdatta@users.noreply.github.com> -Ashok Gelal <401055+ashokgelal@users.noreply.github.com> -Ashraful Islam -Atsushi Tatsuma -Austin <77757836+teleprint-me@users.noreply.github.com> -AustinMroz -BADR -Bach Le -Bailey Chittle <39804642+bachittle@users.noreply.github.com> -BarfingLemurs <128182951+BarfingLemurs@users.noreply.github.com> -Bartowski -Behnam M <58621210+ibehnam@users.noreply.github.com> -Ben Ashbaugh -Ben Garney -Ben Siraphob -Ben Williams -Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com> -Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> -Bernat Vadell -Bingan <70050083+binganao@users.noreply.github.com> -Bodo Graumann -Bono Lv -Borislav Stanimirov -Branden Butler -Brian -Bruce MacDonald -Bryan Honof -CJ Pais -CRD716 -Calvin Laurenson -Cameron -Cameron Kaiser -Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> -Casey Primozic -Casey Primozic -CausalLM <148736309+CausalLM@users.noreply.github.com> -Cebtenzzre -Chad Brewbaker -Chao Jiang -Cheng Shao -Chris Elrod -Chris Kuehl -Christian Demsar -Christian Demsar -Christian Falch <875252+chrfalch@users.noreply.github.com> -Christian Kögler -Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> -Clark Saben <76020733+csaben@users.noreply.github.com> -Clint Herron -CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> -Cuong Trinh Manh -DAN™ -Damian Stewart -Dane Madsen -DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> -Daniel Bevenius -Daniel Drake -Daniel Hiltgen -Daniel Illescas Romero -Daniele <57776841+daniandtheweb@users.noreply.github.com> -DannyDaemonic -Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com> -Dave -Dave Airlie -Dave Airlie -Dave Della Costa -David Friehs -David Kennedy -David Pflug -David Renshaw -David Sommers <12738+databyte@users.noreply.github.com> -David Yang -Dawid Potocki -Dawid Wysocki <62249621+TortillaZHawaii@users.noreply.github.com> -Dean -Deins -Deven Mistry <31466137+deven367@users.noreply.github.com> -Didzis Gosko -Djip007 -Don Mahurin -DooWoong Lee (David) -Doomsdayrs <38189170+Doomsdayrs@users.noreply.github.com> -Douglas Hanley -Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> -Ebey Abraham -Ed Lee -Ed Lepedus -Eddie-Wang -Edward Taylor -Elaine -Elbios <141279586+Elbios@users.noreply.github.com> -Elton Kola -Engininja2 <139037756+Engininja2@users.noreply.github.com> -Equim -Eric Sommerlade -Eric Zhang <34133756+EZForever@users.noreply.github.com> -Erik Garrison -Erik Scholz -Ettore Di Giacinto -Evan Jones -Evan Miller -Eve <139727413+netrunnereve@users.noreply.github.com> -Evgeny Kurnevsky -Ewout ter Hoeven -ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> -FK -Fabian -Fabio R. Sluzala -Faez Shakil -FantasyGmm <16450052+FantasyGmm@users.noreply.github.com> -Fattire <528174+fat-tire@users.noreply.github.com> -Felix -Finn Voorhees -Firat -Folko-Ven <71110216+Folko-Ven@users.noreply.github.com> -Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> -Francisco Melo <43780565+francis2tm@users.noreply.github.com> -Frank Mai -FrankHB -Fred Douglas <43351173+fredlas@users.noreply.github.com> -Frederik Vogel -Gabe Goodhart -GainLee -Galunid -Gary Linscott -Gary Mulder -Gavin Zhao -Genkagaku.GPT -Georgi Gerganov -Gilad S -Giuseppe Scrivano -GiviMAD -Govlzkoy -Guillaume "Vermeille" Sanchez -Guillaume Wenzek -Guoteng <32697156+SolenoidWGT@users.noreply.github.com> -Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> -Haggai Nuchi -Halalaluyafail3 <55773281+Halalaluyafail3@users.noreply.github.com> -Hamdoud Hakem <90524568+hamdoudhakem@users.noreply.github.com> -HanishKVC -Haohui Mai -Haoxiang Fei -Harald Fernengel -Hatsune Miku <129688334+at8u@users.noreply.github.com> -HatsuneMikuUwU33 <173229399+HatsuneMikuUwU33@users.noreply.github.com> -Henk Poley -Henri Vasserman -Henrik Forstén -Herman Semenov -Hesen Peng -Hoang Nguyen -Hong Bo PENG -Hongyu Ouyang <96765450+casavaca@users.noreply.github.com> -Howard Su -Hua Jiang -Huawei Lin -Hugo Roussel -Ian Bull -Ian Bull -Ian Scrivener -Ido S -IgnacioFDM -Igor Okulist -Ikko Eltociear Ashimine -Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> -Ionoclast Laboratories -Isaac McFadyen -IsaacDynamo <61521674+IsaacDynamo@users.noreply.github.com> -Ivan Komarov -Ivan Stepanov -JH23X <165871467+JH23X@users.noreply.github.com> -Jack Mousseau -JackJollimore <130917767+JackJollimore@users.noreply.github.com> -Jaemin Son -Jag Chadha -Jakub N -James A Capozzoli <157492257+jac-jim@users.noreply.github.com> -James Reynolds -Jan Boon -Jan Boon -Jan Ploski -Jannis Schönleber -Jared Van Bortel -Jared Van Bortel -Jason McCartney -Jean-Christophe Hoelt -Jean-Michaël Celerier -Jed Fox -Jeffrey Quesnelle -Jesse Jojo Johnson -Jeximo -Jhen-Jie Hong -Jiahao Li -Jian Liao -JidongZhang-THU <1119708529@qq.com> -Jinwoo Jeong <33892306+williamjeong2@users.noreply.github.com> -Jiří Podivín <66251151+jpodivin@users.noreply.github.com> -Jiří Sejkora -Joan Fontanals -Joan Fontanals -Johan -Johannes Gäßler -Johannes Rudolph -John <78893154+cmp-nct@users.noreply.github.com> -John Balis -John Smith <67539080+kingsidelee@users.noreply.github.com> -JohnnyB -Jonas Wunderlich <32615971+jonas-w@users.noreply.github.com> -Jorge A <161275481+jorgealias@users.noreply.github.com> -Jose Maldonado <63384398+yukiteruamano@users.noreply.github.com> -Joseph Stahl <1269177+josephst@users.noreply.github.com> -Josh Ramer -Joyce -Juan Calderon-Perez <835733+gaby@users.noreply.github.com> -Judd -Julius Arkenberg -Jun Jie <71215065+junnjiee16@users.noreply.github.com> -Junyang Lin -Juraj Bednar -Justin Parker -Justin Suess -Justina Cho -Justine Tunney -Justine Tunney -Juuso Alasuutari -KASR -Kamil Tomšík -Karsten Weiss -Karthick -Karthik Kumar Viswanathan <195178+guilt@users.noreply.github.com> -Karthik Sethuraman -Kasumi <90275229+kasumi-1@users.noreply.github.com> -Kawrakow <48489457+ikawrakow@users.noreply.github.com> -Keiichi Tabata -Kenvix ⭐ -Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> -Kevin Gibbons -Kevin Ji <1146876+kevinji@users.noreply.github.com> -Kevin Kwok -Kevin Lo -Kolen Cheung -Konstantin Herud -Konstantin Zhuravlyov -Kunshang Ji -Kyle Liang -Kyle Mistele -Kylin <56434533+KyL0N@users.noreply.github.com> -Lars Grammel -Laura -Lee <44310445+lx200916@users.noreply.github.com> -Lee Drake -Leng Yue -Leon Knauer -LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> -Leonardo Neumann -Li Tan -Linwei Wang -LoganDark -LostRuins <39025047+LostRuins@users.noreply.github.com> -Luciano -Luo Tian -Lyle Dean -M. Yusuf Sarıgöz -Maarten ter Huurne -Mack Straight -Maël Kerbiriou -MaggotHATE -Manuel <44313466+makuche@users.noreply.github.com> -Marc Köhlbrugge -Marco Matthies <71844+marcom@users.noreply.github.com> -Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> -Marian Cepok -Mark Fairbairn -Marko Tasic -Markus Tavenrath -Martin Delille -Martin Krasser -Martin Schwaighofer -Marvin Gießing -Masaya, Kato <62578291+msy-kato@users.noreply.github.com> -MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com> -Mateusz Charytoniuk -Matheus C. França -Matheus Gabriel Alves Silva -Mathieu Nayrolles -Mathijs de Bruin -Matt Clayton <156335168+mattjcly@users.noreply.github.com> -Matt Pulver -Matteo Boschini <12133566+mbosc@users.noreply.github.com> -Mattheus Chediak -Matthew Tejo -Matvey Soloviev -Max Krasnyansky -Max Krasnyansky -Maxime <672982+maximegmd@users.noreply.github.com> -Maximilian Winter -Meng Zhang -Meng, Hengyu -Merrick Christensen -Michael Coppola -Michael Hueschen -Michael Kesper -Michael Klimenko -Michael Podvitskiy -Michael Potter -Michael de Gans -Michaël de Vries -Mihai -Mike -Mikko Juola -Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> -Mirko185 -Mirror Azure <54669636+MirrorAzure@users.noreply.github.com> -Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com> -Mohammadreza Hendiani -Mohammadreza Hendiani -Murilo Santana -Musab Gultekin -Nam D. Tran <42194884+namtranase@users.noreply.github.com> -Nathan Epstein -NawafAlansari <72708095+NawafAlansari@users.noreply.github.com> -Nebula -Neo Zhang <14088817+arthw@users.noreply.github.com> -Neo Zhang -Neo Zhang Jianyu -Neuman Vong -Nexesenex <124105151+Nexesenex@users.noreply.github.com> -Niall Coates <1349685+Niall-@users.noreply.github.com> -Nicolai Weitkemper -Nicolás Pérez -Nigel Bosch -Niklas Korz -Nikolas <127742645+nneubacher@users.noreply.github.com> -Nindaleth -Oleksandr Nikitin -Oleksii Maryshchenko -Olivier Chafik -Ondřej Čertík -Ouadie EL FAROUKI -Patrice Ferlet -Paul Tsochantaris -Pavol Rusnak -Pedro Cuenca -Peter Sugihara -Phil H <5756783+phiharri@users.noreply.github.com> -Philip Taron -Phillip Kravtsov -Pierre Alexandre SCHEMBRI -Pierrick Hymbert -Przemysław Pawełczyk -Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> -Qingyou Meng -Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com> -RJ Adriaansen -Radoslav Gerganov -Radosław Gryta -Rahul Vivek Nair <68507071+RahulVivekNair@users.noreply.github.com> -Raj Hammeer Singh Hada -Ralph Soika -Rand Xie -Randall Fitzgerald -Reinforce-II -Ren Xuancheng -Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> -RhinoDevel -Riceball LEE -Richard Kiss -Richard Roberson -Rick G <26732651+TheFlipbook@users.noreply.github.com> -Rickard Edén -Rickard Hallerbäck -Rickey Bowers Jr -Riley Stewart -Rinne -Rinne -Robert Brisita <986796+rbrisita@users.noreply.github.com> -Robert Sung-wook Shin -Robey Holderith -Robyn -Roger Meier -Roland <14355895+rbur0425@users.noreply.github.com> -Romain D <90720+Artefact2@users.noreply.github.com> -Romain Neutron -Roman Parykin -Ron Evans -Ron Jailall -Ronny Brendel -Ronsor -Rowan Hart -Rune <43761327+Rune-AI@users.noreply.github.com> -Ryan Landay -Ryder Wishart -Ryuei -Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> -SakuraUmi -Salvador E. Tropea -Sam Spilsbury -Sami Farin <3876865+Safari77@users.noreply.github.com> -Samuel Maynard -Sang-Kil Park -Seb C <47074056+Sebby37@users.noreply.github.com> -Sebastián A -SebastianApel <13675545+SebastianApel@users.noreply.github.com> -Senemu <10880819+Senemu@users.noreply.github.com> -Sergey Alirzaev -Sergio López -Sertaç Özercan <852750+sozercan@users.noreply.github.com> -SeungWon Jeong <65549245+redlion0929@users.noreply.github.com> -ShadovvBeast -Shakhar Dasgupta -Shangning Xu <32517059+xushangning@users.noreply.github.com> -Shijie <821898965@qq.com> -Shintarou Okada -Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> -Shouzheng Liu -Shuichi Tsutsumi -Sigbjørn Skjæret -Simon Willison -Siwen Yu -Sky Yan -Slaren <2141330+slaren@users.noreply.github.com> -Slava Primenko -SoftwareRenderer <138734813+SoftwareRenderer@users.noreply.github.com> -Someone -Someone Serge -Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> -Spencer Sutton -Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com> -Srinivas Billa -Stefan Sydow -Steffen Röcker -Stephan Walter -Stephen Nichols -Steve Grubb -Steven Prichard -Steven Roussey -Steward Garcia <57494570+FSSRepo@users.noreply.github.com> -Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com> -SuperUserNameMan -Tai Duc Nguyen -Taikono-Himazin -Tameem <113388789+AhmadTameem@users.noreply.github.com> -Tamotsu Takahashi -Thái Hoàng Tâm <75922889+RoyalHeart@users.noreply.github.com> -Thatcher Chamberlin -Theia Vogel -Thérence <13496987+Royalphax@users.noreply.github.com> -Thibault Terrasson -Thomas Klausner -Tim Miller -Timmy Knight -Timothy Cronin <40186632+4imothy@users.noreply.github.com> -Ting Lou -Ting Sun -Tobias Lütke -Tom C -Tom Jobbins <784313+TheBloke@users.noreply.github.com> -Tomas -Tomáš Pazdiora -Tristan Druyen -Tristan Ross -Tungsten842 <886724vf@anonaddy.me> -Tungsten842 -Tushar -UEXTM.com <84163508+uextm@users.noreply.github.com> -Ulrich Drepper -Uzo Nweke -Vaibhav Srivastav -Val Kharitonov -Valentin Konovalov -Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com> -Victor Nogueira -Victor Z. Peng -Vlad -Vladimir -Vladimir Malyutin -Vladimir Zorin -Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> -WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> -Weird Constructor -Welby Seely -Wentai Zhang -WillCorticesAI <150854901+WillCorticesAI@users.noreply.github.com> -William Tambellini -Willy Tarreau -Wouter <9594229+DifferentialityDevelopment@users.noreply.github.com> -Wu Jian Ping -Wu Jian Ping -Xiake Sun -Xiang (Kevin) Li -Xiao-Yong Jin -XiaotaoChen -Xiaoyi Chen -Xingchen Song(宋星辰) -Xuan Son Nguyen -Yann Follet <131855179+YannFollet@users.noreply.github.com> -Yaroslav -Yazan Agha-Schrader -Yiming Cui -Yishuo Wang -Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com> -Yui -Yusuf Kağan Hanoğlu -Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> -ZHAOKAI WANG -Zane Shannon -Zay <95888118+isaiahbjork@users.noreply.github.com> -Zenix -Zhang Peiyuan -Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com> -ZhouYuChen -Ziad Ben Hadj-Alouane -Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com> -Zsapi -a-n-n-a-l-e-e <150648636+a-n-n-a-l-e-e@users.noreply.github.com> -adel boussaken -afrideva <95653597+afrideva@users.noreply.github.com> -agray3 -akawrykow <142945436+akawrykow@users.noreply.github.com> -alexpinel <93524949+alexpinel@users.noreply.github.com> -alonfaraj -alwqx -amd-lalithnc -andrijdavid -anon998 <131767832+anon998@users.noreply.github.com> -anzz1 -apaz -apcameron <37645737+apcameron@users.noreply.github.com> -arch-btw <57669023+arch-btw@users.noreply.github.com> -arcrank -arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> -at8u <129688334+at8u@users.noreply.github.com> -automaticcat -bandoti <141645996+bandoti@users.noreply.github.com> -beiller -bhubbb <79117352+bhubbb@users.noreply.github.com> -bmwl -bobqianic <129547291+bobqianic@users.noreply.github.com> -bryanSwk <93190252+bryanSwk@users.noreply.github.com> -bsilvereagle -bssrdf -byte-6174 <88070277+byte-6174@users.noreply.github.com> -cebtenzzre -chaihahaha -chiranko <96988916+chiranko@users.noreply.github.com> -clibdev <52199778+clibdev@users.noreply.github.com> -clyang -cocktailpeanut <121128867+cocktailpeanut@users.noreply.github.com> -coezbek -comex -compilade <113953597+compilade@users.noreply.github.com> -compilade -cpumaxx <163466046+cpumaxx@users.noreply.github.com> -crasm -crasm -daboe01 -david raistrick -ddh0 -ddpasa <112642920+ddpasa@users.noreply.github.com> -deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> -divinity76 -dm4 -dotpy314 <33351922+dotpy314@users.noreply.github.com> -drbh -ds5t5 <145942675+ds5t5@users.noreply.github.com> -dylan -eastriver -ebraminio -eiery <19350831+eiery@users.noreply.github.com> -eric8607242 -fairydreaming <166155368+fairydreaming@users.noreply.github.com> -fraxy-v <65565042+fraxy-v@users.noreply.github.com> -github-actions[bot] -gliptic -goerch -grahameth <96447521+grahameth@users.noreply.github.com> -gwjr <502526+gwjr@users.noreply.github.com> -h-h-h-h <13482553+h-h-h-h@users.noreply.github.com> -hankcs -hoangmit -hongbo.mo <352280764@qq.com> -hopkins385 <98618192+hopkins385@users.noreply.github.com> -howlger -howlger -hutli <6594598+hutli@users.noreply.github.com> -hutli -hutli -hxer7963 -hydai -iSma -iacore <74560659+iacore@users.noreply.github.com> -igarnier -intelmatt <61025942+intelmatt@users.noreply.github.com> -iohub -jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> -jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> -jameswu2014 <545426914@qq.com> -jiez <373447296@qq.com> -jneem -joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> -johnson442 <56517414+johnson442@users.noreply.github.com> -jojorne -jon-chuang <9093549+jon-chuang@users.noreply.github.com> -jp-x-g -jukofyork <69222624+jukofyork@users.noreply.github.com> -junchao-loongson <68935141+junchao-loongson@users.noreply.github.com> -jwj7140 <32943891+jwj7140@users.noreply.github.com> -k.h.lai -kaizau -kalomaze <66376113+kalomaze@users.noreply.github.com> -kang -katsu560 <118887472+katsu560@users.noreply.github.com> -kchro3 <62481661+kchro3@users.noreply.github.com> -khimaros -kiltyj -klosax <131523366+klosax@users.noreply.github.com> -kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> -kunnis -kuronekosaiko -kuvaus <22169537+kuvaus@users.noreply.github.com> -kwin1412 <42286931+kwin1412@users.noreply.github.com> -l3utterfly -ldwang -le.chang -leejet -limitedAtonement -liuwei-git <14815172+liuwei-git@users.noreply.github.com> -lon <114724657+longregen@users.noreply.github.com> -loonerin <132926317+loonerin@users.noreply.github.com> -luoyu-intel -m3ndax -maddes8cht <55592906+maddes8cht@users.noreply.github.com> -makomk -manikbhandari -maor-ps <154728172+maor-ps@users.noreply.github.com> -mdrokz -mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> -minarchist -mj-shifu <77107165+mj-shifu@users.noreply.github.com> -mmyjona -momonga <115213907+mmnga@users.noreply.github.com> -moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> -mzcu -nanahi <130121847+na-na-hi@users.noreply.github.com> -ngc92 <7938269+ngc92@users.noreply.github.com> -nhamanasu <45545786+nhamanasu@users.noreply.github.com> -niansa/tuxifan -niansa/tuxifan -nickp27 -ningshanwutuobang -nold -nopperl <54780682+nopperl@users.noreply.github.com> -nusu-github <29514220+nusu-github@users.noreply.github.com> -olexiyb -omahs <73983677+omahs@users.noreply.github.com> -oobabooga <112222186+oobabooga@users.noreply.github.com> -opparco -ostix360 <55257054+ostix360@users.noreply.github.com> -pengxin99 -perserk -pmysl -postmasters -pudepiedj -qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> -qouoq -qunash -rabidcopy -rankaiyx -rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> -rhuddleston -rimoliga <53384203+rimoliga@users.noreply.github.com> -runfuture -sandyiscool -sasha0552 -semidark -sharpHL <132747147+sharpHL@users.noreply.github.com> -shibe2 -singularity <12184989+singularity-s0@users.noreply.github.com> -sjinzh -sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> -slaren <2141330+slaren@users.noreply.github.com> -slaren -snadampal <87143774+snadampal@users.noreply.github.com> -staviq -stduhpf -strawberrymelonpanda <152940198+strawberrymelonpanda@users.noreply.github.com> -swittk -takov751 <40316768+takov751@users.noreply.github.com> -tarcey -texmex76 <40733439+texmex76@users.noreply.github.com> -thement <40525767+thement@users.noreply.github.com> -tjohnman -tslmy -ubik2 -uint256_t -uint256_t -unbounded -valiray <133289098+valiray@users.noreply.github.com> -vik -viric -vodkaslime <646329483@qq.com> -vvhg1 <94630311+vvhg1@users.noreply.github.com> -vxiiduu <73044267+vxiiduu@users.noreply.github.com> -wbpxre150 <100937007+wbpxre150@users.noreply.github.com> -whoreson <139810751+whoreson@users.noreply.github.com> -woachk <24752637+woachk@users.noreply.github.com> -wonjun Jang -woodx <124784234+woodx9@users.noreply.github.com> -wzy <32936898+Freed-Wu@users.noreply.github.com> -xaedes -xaedes -xloem <0xloem@gmail.com> -yangli2 -yuiseki -zakkor -zhangkaihuo -zhouwg <6889919+zhouwg@users.noreply.github.com> -zhouwg -zrm -Ștefan-Gabriel Muscalu -源文雨 <41315874+fumiama@users.noreply.github.com> -Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> diff --git a/LICENSE b/LICENSE index 34fbbe73..be8d3b11 100644 --- a/LICENSE +++ b/LICENSE @@ -1,8 +1,8 @@ MIT License -Copyright (c) 2023-2024 The ggml authors -Copyright (c) 2023-2024 The llama.cpp authors -Copyright (c) 2024-2025 Iwan Kawrakow +Copyright (c) 2023-2024 The ggml authors (https://github.com/ggml-org/ggml/blob/master/AUTHORS) +Copyright (c) 2023-2024 The llama.cpp authors (https://github.com/ggml-org/llama.cpp/blob/master/AUTHORS) +Copyright (c) 2024-2025 The ik_llama.cpp authors (https://github.com/ikawrakow/ik_llama.cpp/blob/main/AUTHORS) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From f03ae19aade46a66c0f796e00de183780e104d93 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 7 Apr 2025 17:28:19 +0200 Subject: [PATCH 078/132] Update AUTHORS Forgot to add @Nexesenex --- AUTHORS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 227db9d3..15a25569 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,3 +1,3 @@ Kawrakow saood06 - +Nexes the Elder <124105151+Nexesenex@users.noreply.github.com> From 22d7440ba28000874b571b4a44a8b2c39c9a5ba8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 7 Apr 2025 17:31:35 +0200 Subject: [PATCH 079/132] Update AUTHORS Well, there was also the initial MLA PR, which was derived from @fairydreaming --- AUTHORS | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AUTHORS b/AUTHORS index 15a25569..035292f1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,3 +1,5 @@ Kawrakow saood06 Nexes the Elder <124105151+Nexesenex@users.noreply.github.com> +fairydreaming <166155368+fairydreaming@users.noreply.github.com> +Stanisław Szymczyk From 5f44f4b3d006a24267ea02fe65490bb760a01447 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 8 Apr 2025 08:47:24 +0200 Subject: [PATCH 080/132] Guard against attempts to use MLA for non-MLA models (#320) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d41bd6d8..5f4642fb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3245,13 +3245,15 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); - - // DeepSeek MLA - cache.kv_l.reserve(n_layer); - if (cparams.mla_attn == 1 && !cparams.flash_attn) { - cache.kvt_l.reserve(n_layer); + if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) { + // DeepSeek MLA + cache.kv_l.reserve(n_layer); + if (cparams.mla_attn == 1 && !cparams.flash_attn) { + cache.kvt_l.reserve(n_layer); + } + } else { + cache.k_l.reserve(n_layer); + cache.v_l.reserve(n_layer); } bool warn = true; @@ -3299,7 +3301,7 @@ static bool llama_kv_cache_init( cache.v_l.push_back(v); } } - if (cparams.mla_attn && n_mla < n_layer && n_mla > 0) { + if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn && n_mla < n_layer && n_mla > 0) { LLAMA_LOG_ERROR("%s: unexpected situation with %d out of %d layers having MLA enabled\n", __func__, n_mla, int(n_layer)); LLAMA_LOG_ERROR("%s: bailing out\n", __func__); GGML_ABORT("fatal error"); @@ -18568,6 +18570,13 @@ struct llama_context * llama_new_context_with_model( params.seed = time(NULL); } + if (model->arch != LLM_ARCH_DEEPSEEK2 && cparams.mla_attn > 0) { + LLAMA_LOG_WARN("=====================================================================\n"); + LLAMA_LOG_WARN(" MLA is only available for LLM_ARCH_DEEPSEEK2 -> turning off MLA\n"); + LLAMA_LOG_WARN("=====================================================================\n"); + cparams.mla_attn = 0; + } + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); From 474435f58b6a26bc549589966482207fee94aa60 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 10 Apr 2025 09:05:21 +0200 Subject: [PATCH 081/132] LlaMA-4 support (text only) (#321) * llama4: WIP * llama4: this seems to be working --------- Co-authored-by: Iwan Kawrakow --- include/llama.h | 7 +- src/llama-vocab.cpp | 21 ++++ src/llama.cpp | 288 +++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 295 insertions(+), 21 deletions(-) diff --git a/include/llama.h b/include/llama.h index 3275857a..d7376d7d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -100,7 +100,12 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 23, //llama.cpp lists this as 28 + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, //llama.cpp lists this as 28 + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, }; // note: these values should be synchronized with ggml_rope diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 4bd5aa81..09399417 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -439,6 +439,27 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GPT4O: + regex_exprs = { + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: + regex_exprs = { + "\\p{N}+", + "(?=(\\d{3})+(?!\\d))", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_BAILINGMOE: + regex_exprs = { + // original regex from tokenizer.json + // "'(?i:[sdmt]|ll|ve|re)|[^\\r\\n\\p{L}\\p{N}]?+\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]++[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + // FIXME? Changed possessive quantifiers (?+ and ++) to greedy to avoid errors and imatrix hanging (tried atomic grouping but it's not supported?) + "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { diff --git a/src/llama.cpp b/src/llama.cpp index 5f4642fb..5565ffcd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -184,6 +184,7 @@ static std::string format(const char * fmt, ...) { enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA4, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, @@ -232,6 +233,7 @@ enum llm_arch { static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GPT2, "gpt2" }, @@ -319,6 +321,8 @@ enum llm_kv { LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + LLM_KV_INTERLEAVE_MOE_LAYER_STEP, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -422,6 +426,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -614,6 +620,35 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_LLAMA4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_BAICHUAN, { @@ -1432,6 +1467,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -1467,6 +1503,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, + { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, }; @@ -2357,6 +2394,8 @@ enum e_model { MODEL_10B_128x3_66B, MODEL_57B_A14B, MODEL_27B, + MODEL_17B_16E, + MODEL_17B_128E, }; static const size_t kiB = 1024; @@ -2444,6 +2483,14 @@ struct llama_hparams { bool use_alibi = false; bool attn_soft_cap = false; + uint32_t n_moe_layer_step = 0; + bool use_kq_norm = true; + uint32_t n_attn_chunk = 0; + // values below seems to be fixed on llama4 + uint32_t n_no_rope_layer_step = 4; + uint32_t n_attn_temp_floor_scale = 8192; + float f_attn_temp_scale = 0.1; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = -1; @@ -2939,6 +2986,8 @@ struct llama_context { struct llama_kv_cache kv_self; struct llama_control_vector cvec; + std::vector scale_data; + std::unordered_map lora_adapters; std::vector backends; @@ -3015,6 +3064,7 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + struct ggml_tensor * inp_scale = nullptr; // F32 [n_tokens] }; struct llama_lora_weight { @@ -4945,6 +4995,8 @@ static const char * llama_model_type_name(e_model type) { case MODEL_10B_128x3_66B: return "10B+128x3.66B"; case MODEL_57B_A14B: return "57B.A14B"; case MODEL_27B: return "27B"; + case MODEL_17B_16E: return "17Bx16E (Scout)"; + case MODEL_17B_128E: return "17Bx128E (Maverick)"; default: return "?B"; } } @@ -5106,6 +5158,25 @@ static void llm_load_hparams( } } } break; + case LLM_ARCH_LLAMA4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full + hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later + + switch (hparams.n_expert) { + case 16: model.type = MODEL_17B_16E; break; + case 128: model.type = MODEL_17B_128E; break; + default: model.type = MODEL_UNKNOWN; + } + + if (model.type == MODEL_17B_128E) { + hparams.use_kq_norm = false; + } + } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5929,6 +6000,23 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "codeshell") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT4O; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "superbpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "trillion") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TRILLION; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "bailingmoe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6552,6 +6640,7 @@ static bool llm_load_tensors( const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab_type = hparams.n_vocab_type; + const int64_t n_rot = hparams.n_rot; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -6670,6 +6759,61 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_LLAMA4: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_NOT_REQUIRED); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, + llama_model_loader::TENSOR_DUPLICATED); + } + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0"); + for (int i = 0; i < n_layer; ++i) { + bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, + llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; case LLM_ARCH_GROK: { if (n_expert == 0) { @@ -8623,8 +8767,6 @@ static void llm_build_kv_store( // note: storing RoPE-ed version of K in the KV cache ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cache_view = nullptr; if (cparams.flash_attn) { @@ -8884,6 +9026,7 @@ llm_expert_gating_func_type gating_op, int il) { int64_t n_embd = cur->ne[0]; int64_t n_tokens = cur->ne[1]; + bool weight_before_ffn = lctx.model.arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); @@ -8912,6 +9055,12 @@ llm_expert_gating_func_type gating_op, cb(selection_probs, "ffn_moe_probs_biased", il); } + // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k + // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198 + if (lctx.model.arch == LLM_ARCH_LLAMA4) { + selection_probs = logits; + } + // select experts ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] @@ -8940,6 +9089,14 @@ llm_expert_gating_func_type gating_op, cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); + if (weight_before_ffn) { + // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) + ggml_tensor * repeated = ggml_new_tensor_3d(ctx, cur->type, n_embd, n_expert_used, n_tokens); + repeated = ggml_repeat(ctx, cur, repeated); // [n_embd, n_expert_used, n_tokens] + cur = ggml_mul(ctx, repeated, weights); + cb(cur, "ffn_moe_weighted", il); + } + ggml_tensor * par; if (lctx.cparams.fused_moe_up_gate) { par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); @@ -8958,7 +9115,10 @@ llm_expert_gating_func_type gating_op, ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); - experts = ggml_mul(ctx, experts, weights); + if (!weight_before_ffn) { + experts = ggml_mul(ctx, experts, weights); + cb(cur, "ffn_moe_weighted", il); + } if (n_expert_used == 1) { return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0)); @@ -9487,6 +9647,14 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_inpup_scale(int n_tokens) { + int n_pos_per_token = 1; + lctx.inp_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token); + cb(lctx.inp_scale, "inp_scale", -1); + ggml_set_input(lctx.inp_scale); + return lctx.inp_scale; + } + struct ggml_tensor * build_rope_factors(int il) { // choose long/short freq factors based on the context size const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; @@ -9667,14 +9835,19 @@ struct llm_build_context { GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - struct ggml_tensor * cur; - struct ggml_tensor * inpL; + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_attn_scale = nullptr; inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); + if (model.arch == LLM_ARCH_LLAMA4) { + inp_attn_scale = build_inpup_scale(n_tokens); + } + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); @@ -9683,6 +9856,8 @@ struct llm_build_context { for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; + bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true; + // norm cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, @@ -9720,19 +9895,29 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Qcur, "Qcur", il); + if (use_rope) { + Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_attn_scale); + } + + cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (model.arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6); + Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, @@ -9758,6 +9943,7 @@ struct llm_build_context { // feed-forward network if (model.layers[il].ffn_gate_inp == nullptr) { + // non-MoE cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -9770,6 +9956,37 @@ struct llm_build_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); + } else if (model.arch == LLM_ARCH_LLAMA4) { + // llama4 MoE + ggml_tensor * ffn_inp_normed = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, ffn_inp_normed, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SIGMOID, + cb, il); + + // Shared experts + ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_out_merged", il); + } else { // MoE branch cur = llm_build_norm(ctx0, ffn_inp, hparams, @@ -9782,11 +9999,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - nullptr, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - LLM_EXPERT_GATING_FUNC_SOFTMAX, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); } @@ -15162,6 +15379,7 @@ static struct ggml_cgraph * llama_build_graph( switch (model.arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { @@ -15423,6 +15641,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } + if (lctx.inp_pos && lctx.inp_scale) { + int n_tokens = batch.n_tokens; + GGML_ASSERT(ggml_nelements(lctx.inp_scale) >= n_tokens); + if (int(lctx.scale_data.size()) < n_tokens) lctx.scale_data.resize(n_tokens); + int n_pos_per_token = 1; + for (int i = 0; i < n_tokens; ++i) { + lctx.scale_data[i] = std::log(std::floor((batch.pos[i] + 1.0f) / hparams.n_attn_temp_floor_scale) + 1.0f) * hparams.f_attn_temp_scale + 1.0f; + } + ggml_backend_tensor_set(lctx.inp_scale, lctx.scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(lctx.inp_scale)); + } + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -15504,8 +15733,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { // may need to cut off old tokens for sliding window if (data_swa) { - if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; + if (hparams.n_attn_chunk) { + llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; + if (lctx.kv_self.cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { + f = -INFINITY; + } + } else { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } } data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f; } @@ -18949,6 +19185,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA4: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: @@ -20774,6 +21011,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { + return LLM_CHAT_TEMPLATE_LLAMA4; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -21155,6 +21394,15 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|role_start|>assistant<|role_end|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { + // Llama 4 + for (auto message : chat) { + std::string role(message->role); + ss << "<|header_start|>" << role << "<|header_end|>\n\n" << trim(message->content) << "<|eot|>"; + } + if (add_ass) { + ss << "<|header_start|>assistant<|header_end|>\n\n"; + } } else { // template not supported return -1; From 3e7be3d28e626b52e6198cfc8a20c77a262cfeb9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 11 Apr 2025 10:49:18 +0200 Subject: [PATCH 082/132] Correct L4 rms_norm (#324) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 5565ffcd..5340d847 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9913,8 +9913,8 @@ struct llm_build_context { if (model.arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { // Llama4TextL2Norm - Qcur = ggml_rms_norm(ctx0, Qcur, 1e-6); - Kcur = ggml_rms_norm(ctx0, Kcur, 1e-6); + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); cb(Qcur, "Qcur_normed", il); cb(Kcur, "Kcur_normed", il); } From c01449a47869d33f3a2c4272306624ba9d025c3f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 12 Apr 2025 16:17:50 +0200 Subject: [PATCH 083/132] Fix KLD precision (#325) Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 07ac88af..12702693 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -133,7 +133,7 @@ static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob max_logit = std::max(max_logit, logits[i]); min_logit = std::min(min_logit, logits[i]); } - min_logit = std::max(min_logit, max_logit - 16); + min_logit = std::max(min_logit, max_logit - 24); double sum_exp = 0.0; for (int i = 0; i < n_vocab; ++i) { sum_exp += expf(logits[i] - max_logit); From d210661c9121b2e9d21cd8fbe7d6a536b61453d5 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 13 Apr 2025 10:37:55 +0200 Subject: [PATCH 084/132] Improved IQ1_M quantization (#327) * Much faster and it looks like better iq1_m quantiation * Cleanup * Minor --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-quants.c | 168 ++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d8f1ddd1..cc1c8fc6 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -14391,6 +14391,8 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; float sumqx[4], sumq2[4]; + float sumw1[IQ1M_BLOCK_SIZE+1], sumw2[IQ1M_BLOCK_SIZE+1]; + float sumx1[IQ1M_BLOCK_SIZE+1], sumx2[IQ1M_BLOCK_SIZE+1]; const int gindex = iq2_data_index(GGML_TYPE_IQ1_M); @@ -14414,7 +14416,22 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo idx[2*j] = j; } qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = -FLT_MIN, scale = 0.f; + sumw1[0] = sumw2[0] = sumx1[0] = sumx2[0] = 0; + for (int j = 0; j < block_size; ++j) { + int i = idx[2*j]; + if (i < block_size/2) { + sumw1[j+1] = sumw1[j] + weight[i]; + sumx1[j+1] = sumx1[j] + weight[i]*xb[i]; + sumw2[j+1] = sumw2[j]; + sumx2[j+1] = sumx2[j]; + } else { + sumw2[j+1] = sumw2[j] + weight[i]; + sumx2[j+1] = sumx2[j] + weight[i]*xb[i]; + sumw1[j+1] = sumw1[j]; + sumx1[j+1] = sumx1[j]; + } + } + float best_score = 0, scale = 0.f; int besti1 = -1, besti2 = -1, best_k = -1; // 0: +, + // 1: +, - @@ -14422,74 +14439,22 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo // 3: -, - for (int i1 = 0; i1 <= block_size; ++i1) { for (int i2 = i1; i2 <= block_size; ++i2) { - memset(sumqx, 0, 4*sizeof(float)); - memset(sumq2, 0, 4*sizeof(float)); - for (int j = 0; j < i1; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[0]*xb[i]; - sumqx[1] += weight[i]*x_p[0]*xb[i]; - sumqx[2] += weight[i]*x_m[0]*xb[i]; - sumqx[3] += weight[i]*x_m[0]*xb[i]; - sumq2[0] += weight[i]*x_p[0]*x_p[0]; - sumq2[1] += weight[i]*x_p[0]*x_p[0]; - sumq2[2] += weight[i]*x_m[0]*x_m[0]; - sumq2[3] += weight[i]*x_m[0]*x_m[0]; - } else { - sumqx[0] += weight[i]*x_p[0]*xb[i]; - sumqx[2] += weight[i]*x_p[0]*xb[i]; - sumqx[1] += weight[i]*x_m[0]*xb[i]; - sumqx[3] += weight[i]*x_m[0]*xb[i]; - sumq2[0] += weight[i]*x_p[0]*x_p[0]; - sumq2[2] += weight[i]*x_p[0]*x_p[0]; - sumq2[1] += weight[i]*x_m[0]*x_m[0]; - sumq2[3] += weight[i]*x_m[0]*x_m[0]; - } - } - for (int j = i1; j < i2; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[1]*xb[i]; - sumqx[1] += weight[i]*x_p[1]*xb[i]; - sumqx[2] += weight[i]*x_m[1]*xb[i]; - sumqx[3] += weight[i]*x_m[1]*xb[i]; - sumq2[0] += weight[i]*x_p[1]*x_p[1]; - sumq2[1] += weight[i]*x_p[1]*x_p[1]; - sumq2[2] += weight[i]*x_m[1]*x_m[1]; - sumq2[3] += weight[i]*x_m[1]*x_m[1]; - } else { - sumqx[0] += weight[i]*x_p[1]*xb[i]; - sumqx[2] += weight[i]*x_p[1]*xb[i]; - sumqx[1] += weight[i]*x_m[1]*xb[i]; - sumqx[3] += weight[i]*x_m[1]*xb[i]; - sumq2[0] += weight[i]*x_p[1]*x_p[1]; - sumq2[2] += weight[i]*x_p[1]*x_p[1]; - sumq2[1] += weight[i]*x_m[1]*x_m[1]; - sumq2[3] += weight[i]*x_m[1]*x_m[1]; - } - } - for (int j = i2; j < block_size; ++j) { - int i = idx[2*j]; - if (i < block_size/2) { - sumqx[0] += weight[i]*x_p[2]*xb[i]; - sumqx[1] += weight[i]*x_p[2]*xb[i]; - sumqx[2] += weight[i]*x_m[2]*xb[i]; - sumqx[3] += weight[i]*x_m[2]*xb[i]; - sumq2[0] += weight[i]*x_p[2]*x_p[2]; - sumq2[1] += weight[i]*x_p[2]*x_p[2]; - sumq2[2] += weight[i]*x_m[2]*x_m[2]; - sumq2[3] += weight[i]*x_m[2]*x_m[2]; - } else { - sumqx[0] += weight[i]*x_p[2]*xb[i]; - sumqx[2] += weight[i]*x_p[2]*xb[i]; - sumqx[1] += weight[i]*x_m[2]*xb[i]; - sumqx[3] += weight[i]*x_m[2]*xb[i]; - sumq2[0] += weight[i]*x_p[2]*x_p[2]; - sumq2[2] += weight[i]*x_p[2]*x_p[2]; - sumq2[1] += weight[i]*x_m[2]*x_m[2]; - sumq2[3] += weight[i]*x_m[2]*x_m[2]; - } - } + sumqx[0] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] + + (sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2]; + sumqx[1] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] + + (sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2]; + sumqx[2] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] + + (sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2]; + sumqx[3] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] + + (sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2]; + sumq2[0] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] + + (sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2]; + sumq2[1] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] + + (sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2]; + sumq2[2] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] + + (sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2]; + sumq2[3] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] + + (sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2]; for (int k = 0; k < 4; ++k) { if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; @@ -14524,19 +14489,34 @@ void iq1m_process_1block(const float * xb, const float * weight, int8_t * L, flo the_index[k] = grid_index; } if (!all_on_grid) { - float sumqx_f = 0, sumq2_f = 0; - for (int k = 0; k < block_size/8; ++k) { - if (k == 0) xx = best_k < 2 ? x_p : x_m; - else xx = best_k%2 == 0 ? x_p : x_m; - const int8_t * pg = (const int8_t *)(kgrid_q2xs + the_index[k]); - for (int j = 0; j < 8; ++j) { - float w = weight[8*k + j]; - float q = xx[(pg[j] - 1)/2]; - sumqx_f += w*q*xb[8*k+j]; - sumq2_f += w*q*q; + sumqx[0] = sumqx[1] = sumqx[2] = sumqx[3] = 0; + sumq2[0] = sumq2[1] = sumq2[2] = sumq2[3] = 0; + for (int j = 0; j < block_size; ++j) { + float w = weight[j]; + float qp = x_p[L[j]]; + float qm = x_m[L[j]]; + sumqx[0] += w*xb[j]*qp; + sumq2[0] += w*qp*qp; + sumqx[3] += w*xb[j]*qm; + sumq2[3] += w*qm*qm; + if (j < 8) { + sumqx[1] += w*xb[j]*qp; + sumq2[1] += w*qp*qp; + sumqx[2] += w*xb[j]*qm; + sumq2[2] += w*qm*qm; + } else { + sumqx[2] += w*xb[j]*qp; + sumq2[2] += w*qp*qp; + sumqx[1] += w*xb[j]*qm; + sumq2[1] += w*qm*qm; + } + } + best_score = 0; + for (int k = 0; k < 4; ++k) { + if (sumqx[k] > 0 && sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) { + scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; best_k = k; } } - if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f; } *the_scale = scale; *the_shift = best_k; @@ -14570,6 +14550,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA}; const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA}; const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88}; + float all_sigma2[QK_K/32]; iq1m_scale_t s; const float * xx; @@ -14582,11 +14563,18 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy float max_scale = 0; const float * xbl = x + QK_K*ibl; - float sumx2 = 0; - for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; - float sigma2 = 2*sumx2/QK_K; + for (int ib = 0; ib < QK_K/32; ++ib) { + const float * xb = xbl + 32*ib; + float sumx2 = 0; + for (int i = 0; i < 32; ++i) sumx2 += xb[i]*xb[i]; + all_sigma2[ib] = 1.5f*sumx2/32; + } + //float sumx2 = 0; + //for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + //float sigma2 = 1.5f*sumx2/QK_K; for (int ib = 0; ib < QK_K/block_size; ++ib) { + float sigma2 = all_sigma2[ib/2]; const float * xb = xbl + block_size*ib; if (quant_weights) { const float * qw = quant_weights + QK_K*ibl + block_size*ib; @@ -14595,12 +14583,21 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; } float max = fabsf(xb[0]); - for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); + float sumwx = 0; + for (int i = 1; i < block_size; ++i) { + float ax = fabsf(xb[i]); + max = MAX(max, ax); + sumwx += weight[i]*ax; + } if (max < GROUP_MAX_EPS_IQ1_M) { scales[ib] = 0; memset(L, 1, block_size); continue; } + if (sumwx == 0) { + // weight is zero everywhere where xb is not zero => ignore + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } int best_k = -1; iq1m_process_1block(xb, weight, L, &scales[ib], index, &best_k, pairs); @@ -14621,6 +14618,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy float id = 1/d; float sumqx_f = 0, sumq2_f = 0; for (int ib = 0; ib < QK_K/block_size; ++ib) { + float sigma2 = all_sigma2[ib/2]; int l = nearest_int(0.5f*(id*scales[ib+0]-1)); l = MAX(0, MIN(7, l)); sc[ib/4] |= (l << 3*(ib%4)); @@ -14645,7 +14643,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy } } if (sumq2_f > 0) d = sumqx_f/sumq2_f; - s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed. + s.f16 = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed. sc[0] |= ((s.u16 & 0x000f) << 12); sc[1] |= ((s.u16 & 0x00f0) << 8); sc[2] |= ((s.u16 & 0x0f00) << 4); From 028e0cfa196ae6adf6cf2f7c98110ef9e71acf91 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 14 Apr 2025 19:41:31 +0200 Subject: [PATCH 085/132] Add ability to hide imatrix details in llama-quantize (#329) Co-authored-by: Iwan Kawrakow --- examples/quantize/quantize.cpp | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 0a91b537..60cf260c 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -142,11 +142,12 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp // [[noreturn]] static void usage(const char * executable) { - printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--attn-q-type] [--attn-k-type] [--attn-v-type] [--attn-qkv-type] [--attn-output-type] [--ffn-gate-type] [--ffn-down-type] [--ffn-up-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); + printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--hide-imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--attn-q-type] [--attn-k-type] [--attn-v-type] [--attn-qkv-type] [--attn-output-type] [--ffn-gate-type] [--ffn-down-type] [--ffn-up-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); printf(" --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n"); + printf(" --hide-imatrix: do not store imatrix details in the quantized model\n"); printf(" --include-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor.\n"); @@ -337,6 +338,8 @@ int main(int argc, char ** argv) { std::vector repack_patterns; + bool hide_imatrix = false; + for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { params.quantize_output_tensor = false; @@ -429,6 +432,8 @@ int main(int argc, char ** argv) { } else { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--hide-imatrix") == 0) { + hide_imatrix = true; } else if (strcmp(argv[arg_idx], "--include-weights") == 0) { if (arg_idx < argc-1) { included_weights.emplace_back(argv[++arg_idx]); @@ -469,7 +474,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.val_str, imatrix_file.c_str(), 127); + if (hide_imatrix) { + strncpy(kvo.val_str, "top_secret", 127); + } else { + strncpy(kvo.val_str, imatrix_file.c_str(), 127); + } kvo.val_str[127] = '\0'; kv_overrides.emplace_back(std::move(kvo)); } @@ -477,7 +486,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_DATASET); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - strncpy(kvo.val_str, imatrix_dataset.c_str(), 127); + if (hide_imatrix) { + strncpy(kvo.val_str, "top_secret", 127); + } else { + strncpy(kvo.val_str, imatrix_dataset.c_str(), 127); + } kvo.val_str[127] = '\0'; kv_overrides.emplace_back(std::move(kvo)); } @@ -486,7 +499,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = imatrix_data.size(); + if (hide_imatrix) { + kvo.val_i64 = 0; + } else { + kvo.val_i64 = imatrix_data.size(); + } kv_overrides.emplace_back(std::move(kvo)); } @@ -494,7 +511,11 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = m_last_call; + if (hide_imatrix) { + kvo.val_i64 = 0; + } else { + kvo.val_i64 = m_last_call; + } kv_overrides.emplace_back(std::move(kvo)); } } From 05dbbeaf1489c55533adb5a032e077fab6cd95ad Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 14 Apr 2025 19:43:19 +0200 Subject: [PATCH 086/132] imatrix: collect layer influence statistics (#328) * imatrix: collect layer influence statistics * imatrix: collect layer influence statiscs also for the last layer For the last layer we need to use the input for the output.weight tensor. Last layer(s) tend(s) to be important, so it is useful to also have its influence metric. * imatrix: separate metric for attention and ffn importance * Use stripped tensor name, not src0->name --------- Co-authored-by: Iwan Kawrakow --- examples/imatrix/imatrix.cpp | 171 ++++++++++++++++++++++++++++++++++- 1 file changed, 169 insertions(+), 2 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index d8a43049..d1693fa5 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -19,6 +19,8 @@ #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -49,13 +51,59 @@ public: bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix(int ncall = -1) const; bool load_imatrix(const char * file_name); + void set_collect_lsim(bool yes_or_no) { m_collect_lsim = yes_or_no; } + void print_layer_importance(); private: std::unordered_map m_stats; gpt_params m_params; std::mutex m_mutex; int m_last_call = 0; + int m_last_layer = 9999; + int m_last_ffn = -1; std::vector m_src1_data; std::vector m_ids; // the expert ids from ggml_mul_mat_id + std::vector m_last_input; + std::vector m_ffn_input; + std::vector> m_layer_sim; + std::vector> m_attn_sim; + std::vector> m_ffn_sim; + bool m_collect_lsim = false; + + std::optional layer_index(const std::string& name) const { + if (name == m_params.output_tensor_name && m_last_layer < 199) { + return m_last_layer + 1; + } + if (auto pos = name.find("blk."); pos == 0) { + pos += 4; + if (auto pos1 = name.find('.', pos); pos1 != std::string::npos) { + auto index_str = name.substr(pos, pos1 - pos); + std::istringstream str(index_str); + int index; str >> index; + if (!str.fail()) return index; + } + } + return std::nullopt; + } + + static inline double cosine_similarity(int n, const float * x, const float * y) { + double sumxy = 0, sumx2 = 0, sumy2 = 0; + for (int j = 0; j < n; ++j) { + sumxy += x[j]*y[j]; sumx2 += x[j]*x[j]; sumy2 += y[j]*y[j]; + } + double cos_sim = sumx2 > 0 && sumy2 > 0 ? sumxy/sqrt(sumx2*sumy2) : 0; + return cos_sim; + } + + static inline void collect_cos_similarity(int nrow, int n, const float * x, const float * y, std::pair& p) { + for (int row = 0; row < nrow; ++row) { + p.first += cosine_similarity(n, x, y); + p.second += 1; + x += n; + y += n; + } + } + + static void print_layer_importance(const char * msg, const std::vector>& sim); }; // remove any prefix and suffixes from the name @@ -77,6 +125,45 @@ static std::string filter_tensor_name(const char * name) { return wname; } +void IMatrixCollector::print_layer_importance(const char * msg, const std::vector>& sim) { + if (sim.empty()) return; + std::vector> layers; + layers.reserve(sim.size()); + for (int i = 0; i < int(sim.size()); ++i) { + if (sim[i].second > 0) layers.emplace_back(float(std::abs(sim[i].first/sim[i].second)), i); + } + if (layers.empty()) return; + std::sort(layers.begin(), layers.end()); + printf("%s\n", msg); + //printf("======================== sorted layer importances\n"); + int j = 0; + for (auto& p : layers) { + int i = p.second; + printf("%3d: Layer %3d, = %g\n", j++, i, sim[i].first/sim[i].second); + } +} + +void IMatrixCollector::print_layer_importance() { + print_layer_importance("\n======================== sorted layer importances", m_layer_sim); + print_layer_importance("\n======================== sorted attention importances", m_attn_sim); + print_layer_importance("\n======================== sorted ffn importances", m_ffn_sim); + //printf("%s: have %d layers\n", __func__, int(m_layer_sim.size())); + //if (m_layer_sim.empty()) return; + //std::vector> layers; + //layers.reserve(m_layer_sim.size()); + //for (int i = 0; i < int(m_layer_sim.size()); ++i) { + // if (m_layer_sim[i].second > 0) layers.emplace_back(float(std::abs(m_layer_sim[i].first/m_layer_sim[i].second)), i); + //} + //if (layers.empty()) return; + //std::sort(layers.begin(), layers.end()); + //printf("======================== sorted layer importances\n"); + //int j = 0; + //for (auto& p : layers) { + // int i = p.second; + // printf("%3d: Layer %3d, = %g\n", j++, i, m_layer_sim[i].first/m_layer_sim[i].second); + //} +} + bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) { GGML_UNUSED(user_data); @@ -92,7 +179,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; //printf("wname = %s\n", wname.c_str()); - if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == m_params.output_tensor_name))) return false; + if (!(wname.substr(0, 4) == "blk." || ((m_params.process_output || m_collect_lsim) && wname == m_params.output_tensor_name))) return false; return true; } @@ -108,6 +195,33 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * const float * data = is_host ? (const float *) src1->data : m_src1_data.data(); + if (m_collect_lsim) { + if (wname.find(".ffn_") != std::string::npos) { + if (auto index = layer_index(wname); index.has_value() && *index == m_last_layer && *index != m_last_ffn) { + int n = src1->ne[0]; + int nrow = t->op == GGML_OP_MUL_MAT_ID ? src1->ne[2] : src1->ne[1]; + if (t->op == GGML_OP_MUL_MAT_ID) { + GGML_ASSERT(src1->ne[1] == 1); + } + if (m_ffn_input.empty()) { + m_ffn_input.resize(nrow*n); + } else { + if ((int)m_ffn_input.size() != nrow*n) { + printf("Oops, inconsistent ffn size\n"); exit(1); + } + } + std::memcpy(m_ffn_input.data(), data, nrow*n*sizeof(float)); + if (m_ffn_input.size() != m_last_input.size()) { + printf("Oops, inconsistent ffn vs last_input size\n"); exit(1); + } + if (m_attn_sim.size() < *index + 1) m_attn_sim.resize(*index + 1); + auto& p = m_attn_sim[*index]; + collect_cos_similarity(nrow, n, m_ffn_input.data(), m_last_input.data(), p); + m_last_ffn = *index; + } + } + } + // this has been adapted to the new format of storing merged experts in a single 3d tensor // ref: https://github.com/ggerganov/llama.cpp/pull/6387 if (t->op == GGML_OP_MUL_MAT_ID) { @@ -182,6 +296,39 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * } } } else { + if (m_collect_lsim) { + // We only need to do it here and not in the MoE branch above because the first tensor in a layer + // never is a MoE tensor + if (auto index = layer_index(wname); index.has_value()) { + if (*index != m_last_layer) { + if (*index > 0) { + if (m_last_input.size() != src1->ne[0]*src1->ne[1]) { + printf("Oops: different size (%d vs %d). Tensor name was %s, m_last_layer = %d\n", + (int)(src1->ne[0]*src1->ne[1]), (int)m_last_input.size(), src0->name, m_last_layer); + exit(1); + } + if (*index > m_layer_sim.size()) m_layer_sim.resize(*index); + auto& p = m_layer_sim[*index - 1]; + collect_cos_similarity(src1->ne[1], src1->ne[0], m_last_input.data(), (const float *)data, p); + if (*index == m_last_ffn + 1) { + if (*index > m_ffn_sim.size()) m_ffn_sim.resize(*index); + auto& p1 = m_ffn_sim[*index-1]; + collect_cos_similarity(src1->ne[1], src1->ne[0], m_ffn_input.data(), (const float *)data, p1); + } + } + m_last_layer = *index; + if (m_last_input.empty()) { + m_last_input.resize(src1->ne[0]*src1->ne[1]); + } else { + if (m_last_input.size() != src1->ne[0]*src1->ne[1]) { + printf("Oops\n"); exit(1); + } + } + //printf("Copying src1 to m_last_input\n"); + std::memcpy(m_last_input.data(), data, src1->ne[0]*src1->ne[1]*sizeof(float)); + } + } + } auto & e = m_stats[wname]; if (e.values.empty()) { e.values.resize(src1->ne[0], 0); @@ -622,7 +769,25 @@ int main(int argc, char ** argv) { params.logits_all = true; params.verbosity = 1; - if (!gpt_params_parse(argc, argv, params)) { + bool lsim = false; + // + // Do not pollute common with totally imatrix specific arguments as it was done in mainline. + // Instead, parse imatrix specific args here, push unknown args into a new array of args, + // and pass that to gpt_params_parse(). + // + std::vector args; + args.reserve(argc); + args.push_back(argv[0]); + for (int i = 1; i < argc; ++i) { + std::string arg{argv[i]}; + if (arg == "-lsim" || arg == "--layer-similarity") { + lsim = true; + } else { + args.push_back(argv[i]); + } + } + + if (!gpt_params_parse(args.size(), args.data(), params)) { print_usage(argc, argv, params); return 1; } @@ -630,6 +795,7 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, params.n_ctx); g_collector.set_params(params); + g_collector.set_collect_lsim(lsim); for (const auto & in_file : params.in_files) { printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str()); @@ -680,6 +846,7 @@ int main(int argc, char ** argv) { } g_collector.save_imatrix(); + g_collector.print_layer_importance(); llama_print_timings(ctx); From 1bbb143eb36ce3f6b771fd4c0deb02826c0a7800 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 15 Apr 2025 17:05:31 +0200 Subject: [PATCH 087/132] Allow q8_0 KV cache for head size 256 (#330) * Allow q8_0 KV cache for head size 256 * We need also these --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 5 +++++ ggml/src/ggml-cuda/fattn.cu | 4 ++++ .../fattn-vec-f16-instance-hs256-q8_0-q8_0.cu | 5 +++++ .../fattn-vec-f32-instance-hs256-q8_0-q8_0.cu | 5 +++++ 4 files changed, 19 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 63e16d52..1f62b882 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3578,6 +3578,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->src[0]->ne[0] == 128) { return true; } + if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 && + (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) && + (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) { + return true; + } if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) { return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8c3f4000..55116058 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -248,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) @@ -265,6 +266,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) @@ -347,6 +349,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) @@ -358,6 +361,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu new file mode 100644 index 00000000..f257f5d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f16.cuh" + +DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu new file mode 100644 index 00000000..a0f03f49 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec-f32.cuh" + +DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); From f7c5a94e756e4add4d531d295ae23493d9857508 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 15 Apr 2025 17:18:50 +0200 Subject: [PATCH 088/132] Better gemm/gemv on AVX2 fr q4_0_r8 (#331) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 9a34270b..78270f5e 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3620,8 +3620,7 @@ inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) { _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55))); auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)), _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff))); - auto sumi = _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(sumi1, sumi2)), - _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(sumi3, sumi4))); + auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_add_epi16(sumi1, sumi2), _mm256_add_epi16(sumi3, sumi4))); #endif return sumi; } From 3bb64d9330d5336d76b036535474d8a4b273373c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 17 Apr 2025 08:08:40 +0200 Subject: [PATCH 089/132] Better TG performance for GQA models (CPU) (#332) * Slightly better CPU TG performance for GQA * Better CPU FA implementation for TG when GQA * Minor --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 32 ++++++++++------- ggml/src/iqk/iqk_flash_attn.cpp | 61 +++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_mul_mat.cpp | 53 ++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 25694fc7..83a48cb6 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21781,19 +21781,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const struct ggml_tensor * q = node->src[0]; const struct ggml_tensor * k = node->src[1]; if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { - int nstep_k = k->ne[1]/32; - int gcd_k = simple_gcd(nstep_k, n_tasks); - if (gcd_k > 1) { - int nth_k = n_tasks/gcd_k; - int rk2 = q->ne[2]/k->ne[2]; - int nq_per_thread = (rk2 + nth_k - 1)/nth_k; - size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; - if (ggml_is_quantized(k->type)) { - enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; - size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); - size += q->ne[2]*row_size; - } + if (k->ne[2] > 1) { + int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); + int nstep_k = k->ne[2]*k->ne[1]/nk; + size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); + size_t size = nstep_k*result_size; cur = MAX(cur, size); + } else { + int nstep_k = k->ne[1]/32; + int gcd_k = simple_gcd(nstep_k, n_tasks); + if (gcd_k > 1) { + int nth_k = n_tasks/gcd_k; + int rk2 = q->ne[2]/k->ne[2]; + int nq_per_thread = (rk2 + nth_k - 1)/nth_k; + size_t size = (Dv + 16)*nq_per_thread*sizeof(float)*n_tasks; + if (ggml_is_quantized(k->type)) { + enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type; + size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); + size += q->ne[2]*row_size; + } + cur = MAX(cur, size); + } } } #endif diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index e302c944..3f1d6dc2 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -153,6 +153,67 @@ bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, } } + if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { + int nk = 32 * (nek2*nek1/(32*nth)); + int nkk = (nek1 + nk - 1)/nk; + int nstep_k = nek2*nkk; + auto result_size = (Dv + 16)*rk2*sizeof(float); + //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k); + for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) { + int ik02 = istep_k/nkk; + int ik01 = nk*(istep_k - ik02*nkk); + int this_nk = ik01 + nk <= nek1 ? nk : nek1 - ik01; + if (this_nk <= 0) break; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + // We have nkk results for each head + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + // ik02*rk2 + il = iq2 (il = 0...rk2-1) => ik02 = iq2/rk2, il = iq2%rk2; + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + std::memset(Racc, 0, Dv*sizeof(float)); + float M = -INFINITY, S = 0; + for (int ikk = 0; ikk < nkk; ++ikk) { + int istep_k = ik02*nkk + ikk; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + if (Mj[il] == -INFINITY) continue; + if (Mj[il] > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj[il]; + } else { + float c = exp(M - Mj[il]); + S = c*S + Sj[il]; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj[il]; + } else { + float c = exp(Mj[il] - M); + S += c*Sj[il]; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + // I keep changing my mind what is the best strategy to split the threads when processing // multiple heads. This is my current thinking, the commented out code below was the previous. int ntg = nth/simple_gcd(neq2*neq3, nth); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 78270f5e..424a65af 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -451,6 +451,51 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, auto r3 = ne13 / ne03; if (ne13 == 1 && Ny == 1 && r2 > 1) { + if (Nx >= 256 && Nx%32 == 0) { + int nx32 = Nx/32; + int nchunk = nx32*ne02; + if (r2 <= 8) { + MulMat mm; + if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false; + int nx64 = Nx/64; + int nchunk64 = nx64*ne02; + for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) { + int i02 = ichunk/nx64; + int ix = 64*(ichunk - i02*nx64); + DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64); + } + int ix0 = 64*nx64; + if (ix0 < Nx) { + nx32 -= 2*nx64; + nchunk = nx32*ne02; + for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + int i02 = ichunk/nx32; + int ix = ix0 + 32*(ichunk - i02*nx32); + DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32); + } + } + //for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + // int i02 = ichunk/nx32; + // int ix = 32*(ichunk - i02*nx32); + // DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0}; + // mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32); + //} + return true; + } + for (int ichunk = ith; ichunk < nchunk; ichunk += nth) { + int i02 = ichunk/nx32; + int ix = ichunk - i02*nx32; + if (!iqk_mul_mat(32, r2, ne00, + typeA, (const char *)A + 32*ix*strideA + i02*nb02, strideA, + typeB, (const char *)B + i02*r2*nb12, nb12, + C + 32*ix + r2*i02*nb2, nb2, 0, 1)) return false; + + } + return true; + } + //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02); int gcd = simple_gcd(ne02, nth); int counter = 0; for (int64_t i12 = 0; i12 < ne02; i12++) { @@ -17153,6 +17198,14 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } + else if (nq1 >= 4) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } + else if (nq1 >= 2) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + } else { FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); From 93cd77b65501246603061c7ee2801a992e3c6312 Mon Sep 17 00:00:00 2001 From: saood06 Date: Mon, 21 Apr 2025 02:13:46 -0500 Subject: [PATCH 090/132] Fix termux/android build (#336) * Attempt fix * Attempt fix 2 * Attempt fix 3 * Attempt fix 4 * Attempt fix 5 * Attempt fix 6 * Attempt fix 7 * Attempt fix 8 * Attempt fix 9 * Attempt fix 10 * Attempt fix 11 * Attempt fix 12 * Attempt fix 13 --- ggml/src/iqk/iqk_config.h | 14 ++++++++++++++ ggml/src/iqk/iqk_flash_attn.cpp | 9 +++++---- ggml/src/iqk/iqk_mul_mat.cpp | 16 ++++++++-------- ggml/src/iqk/iqk_mul_mat.h | 11 ++++++----- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/ggml/src/iqk/iqk_config.h b/ggml/src/iqk/iqk_config.h index 29645f4e..3d8238d7 100644 --- a/ggml/src/iqk/iqk_config.h +++ b/ggml/src/iqk/iqk_config.h @@ -14,6 +14,20 @@ #define IQK_IMPLEMENT #endif +#ifdef GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BUILD +# define IQK_API __declspec(dllexport) +# else +# define IQK_API __declspec(dllimport) +# endif +# else +# define IQK_API __attribute__ ((visibility ("default"))) +# endif +#else +# define IQK_API +#endif + #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) #define IQK_ALWAYS_INLINE inline diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 3f1d6dc2..639b79af 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -29,7 +29,7 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) { // TODO: get the ggml_type enum here without polution // -bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, +extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, int neq3, int neq2, long nbq3, long nbq2, int nek3, int nek2, long nbk3, long nbk2, int nev3, int nev2, long nbv3, long nbv2, @@ -258,9 +258,10 @@ bool iqk_flash_attn_noalibi([[maybe_unused]] int type_q, [[maybe_unused]] int ty [[maybe_unused]] int nek3, [[maybe_unused]] int nek2, [[maybe_unused]] long nbk3, [[maybe_unused]] long nbk2, [[maybe_unused]] int nev3, [[maybe_unused]] int nev2, [[maybe_unused]] long nbv3, [[maybe_unused]] long nbv2, [[maybe_unused]] int ne2, [[maybe_unused]] int ne1, [[maybe_unused]] long nb1, - [[maybe_unused]] int int_type_k, // type of k - [[maybe_unused]] int int_type_v, // type of v - [[maybe_unused]] int D, // head size + [[maybe_unused]] int type_k, // type of k + [[maybe_unused]] int type_v, // type of v + [[maybe_unused]] int Dk, // K head size + [[maybe_unused]] int Dv, // V head size [[maybe_unused]] int nq, // number of columns in q [[maybe_unused]] int nk, // number of rows in k [[maybe_unused]] int stride_q, // distance between q columns in bytes diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 424a65af..29ae9d53 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -403,7 +403,7 @@ private: } -bool iqk_mul_mat(long Nx, long Ny, long ne00, +extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth) { @@ -440,7 +440,7 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) { } } -bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, +extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, long ne02, long ne03, long ne12, long ne13, long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, int typeA, const void * A, long strideA, @@ -545,7 +545,7 @@ bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, return true; } -bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, +extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { @@ -571,7 +571,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, return true; } -bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, +extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { @@ -17550,11 +17550,11 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k #else // IQK_IMPLEMENT -bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { +extern "C" IQK_API bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) { return false; } -bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/, +extern "C" IQK_API bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/, long /*ne02*/, long /*ne03*/, long /*ne12*/, long /*ne13*/, long /*nb02*/, long /*nb03*/, long /*nb12*/, long /*nb13*/, long /*nb2*/, long /*nb3*/, int /*typeA*/, const void * /*A*/, long /*strideA*/, @@ -17563,12 +17563,12 @@ bool iqk_mul_mat_4d(long /*Nx*/, long /*Ny*/, long /*ne00*/, return false; } -bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, +extern "C" IQK_API bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const void *, long, float *, long, long, const void *, int, int) { return false; } -bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/, +extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*ne00*/, int /*ne11*/, int /*unary_op*/, int /*typeA*/, const void * /*Aup*/, const void * /*Agate*/, long /*strideA*/, int /*typeB*/, const void * /*B*/, long /*strideB*/, float * /*C*/, long /*nb1*/, long /*nb2*/, const void * /*vrow_mapping*/, int /*ith*/, int /*nth*/) { diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index d91c4710..6f44af52 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -7,35 +7,36 @@ #pragma once #include #include +#include "iqk_config.h" #ifdef __cplusplus extern "C" { #endif -bool iqk_mul_mat(long Nx, long Ny, long ne00, +IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth); -bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, +IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00, long ne02, long ne03, long ne12, long ne13, long nb02, long nb03, long nb12, long nb13, long nb2, long nb3, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long stride_C, int ith, int nth); -bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, +IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); -bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, +IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op, int typeA, const void * Aup, const void * Agate, long strideA, int typeB, const void * B, long strideB, float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth); typedef void (*barrier_t) (void *); -bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, +IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias, int neq3, int neq2, long nbq3, long nbq2, int nek3, int nek2, long nbk3, long nbk2, int nev3, int nev2, long nbv3, long nbv2, From cc398007238cbbb064609cbdc9bc4aab03c658d7 Mon Sep 17 00:00:00 2001 From: saood06 Date: Tue, 22 Apr 2025 01:34:13 -0500 Subject: [PATCH 091/132] Add support for bitnet2b_2501 model (#337) * add support for bitnet2b_2501 model * Fixes * Support both model names --------- Co-authored-by: potassiummmm --- convert_hf_to_gguf.py | 1 + gguf-py/gguf/constants.py | 24 +++ gguf-py/gguf/tensor_mapping.py | 5 + src/llama.cpp | 301 ++++++++++++++++++++++++++++++++- 4 files changed, 330 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 1ee82724..a6ab09c0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1598,6 +1598,7 @@ class LlamaModel(Model): @Model.register("BitnetForCausalLM") +@Model.register("BitNetForCausalLM") class BitnetModel(Model): model_arch = gguf.MODEL_ARCH.BITNET diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 93c614d6..22cde145 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -219,6 +219,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() BITNET = auto() + BITNET_25 = auto() T5 = auto() T5ENCODER = auto() JAIS = auto() @@ -351,6 +352,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.BITNET_25: "bitnet-25", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", @@ -1019,6 +1021,28 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_SUB_NORM, MODEL_TENSOR.FFN_SUB_NORM, ], + MODEL_ARCH.BITNET_25: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.ATTN_SUB_NORM, + MODEL_TENSOR.FFN_SUB_NORM, + ], MODEL_ARCH.T5: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e8725426..1dea6a82 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -131,6 +131,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.qkv_proj", # phi3 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm + "layers.{bid}.attention.wqkv", ), # Attention query @@ -464,10 +465,14 @@ class TensorNameMap: MODEL_TENSOR.ATTN_SUB_NORM: ( "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet + "layers.{bid}.attention.attn_sub_norm", # bitnet + "model.layers.{bid}.self_attn.attn_sub_norm", ), MODEL_TENSOR.FFN_SUB_NORM: ( "model.layers.{bid}.mlp.ffn_layernorm", # bitnet + "layers.{bid}.feed_forward.ffn_sub_norm", # bitnet + "model.layers.{bid}.mlp.ffn_sub_norm", ), MODEL_TENSOR.DEC_ATTN_NORM: ( diff --git a/src/llama.cpp b/src/llama.cpp index 5340d847..dcac3833 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -223,6 +223,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_BITNET, + LLM_ARCH_BITNET_25, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, @@ -272,6 +273,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_BITNET_25, "bitnet-25" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, @@ -1323,6 +1325,34 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, }, }, + { + LLM_ARCH_BITNET_25, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, { LLM_ARCH_T5, { @@ -1468,6 +1498,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_LLAMA4, + LLM_CHAT_TEMPLATE_BITNET, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -1504,6 +1535,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, + { "bitnet", LLM_CHAT_TEMPLATE_BITNET }, }; @@ -5120,7 +5152,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -5690,6 +5722,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET_25: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: model.type = e_model::MODEL_2B; break; // bitnet2b_2501 + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_T5: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8107,6 +8148,90 @@ static bool llm_load_tensors( layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_BITNET_25: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + + if (n_expert == 0) { + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_exps) { + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + } else { + // merge split expert into a single tensor for compatibility with older models + // requires disabling mmap + use_mmap_buffer = false; + + ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type; + ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type; + ggml_type type_up = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, 0).c_str())->type; + + layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd, n_ff, n_expert); + layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down, n_ff, n_embd, n_expert); + layer.ffn_up_exps = ggml_new_tensor_3d(ctx_split, type_up, n_embd, n_ff, n_expert); + + ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str()); + ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str()); + ggml_set_name(layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i).c_str()); + + for (uint32_t x = 0; x < n_expert; ++x) { + // the individual experts are loaded into a view of the merged tensor + ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x); + ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x); + ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x); + } + } + } + } + } break; case LLM_ARCH_T5: { const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; @@ -14731,6 +14856,156 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet_25() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + // printf("%f\n\n\n\n",((float*)rope_factors->data)[1]); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + NULL, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_sub_norm", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } + if (model.layers[il].bo) { + cur = ggml_add(ctx0, cur, model.layers[il].bo); + } + cb(cur, "attn_o_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + // n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + NULL, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].ffn_sub_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_sub_norm", il); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } + cb(cur, "ffn_down", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_t5_encoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -15527,6 +15802,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_bitnet(); } break; + case LLM_ARCH_BITNET_25: + { + result = llm.build_bitnet_25(); + } break; case LLM_ARCH_T5: { if (lctx.is_encoding) { @@ -19210,6 +19489,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: + case LLM_ARCH_BITNET_25: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: @@ -21403,6 +21683,25 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|header_start|>assistant<|header_end|>\n\n"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_BITNET) { + // bitnet-25 + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "System: "; + ss << message->content; + } else if (role == "user") { + ss << "User: "; + if (!system_prompt.empty()) { + ss << system_prompt; + system_prompt = ""; + } + ss << message->content << "<|eot_id|>Assistant: "; + } else { + ss << message->content; + } + } } else { // template not supported return -1; From 9dac3edf2f3d9b52f83e1a9461e45523bb33314e Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 22 Apr 2025 08:46:31 +0200 Subject: [PATCH 092/132] BitNet adjustments (#338) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 60 +++++++++++++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index dcac3833..d0fd9c48 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8150,16 +8150,16 @@ static bool llm_load_tensors( } break; case LLM_ARCH_BITNET_25: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -8169,42 +8169,42 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); - layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); - - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + layer.attn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + layer.ffn_sub_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); if (n_expert == 0) { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); // optional MLP bias - layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { - layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); - layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); } else { // merge split expert into a single tensor for compatibility with older models // requires disabling mmap @@ -14931,7 +14931,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, NULL, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_sub_norm, NULL, LLM_NORM_RMS, cb, il); @@ -14997,7 +14997,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); From 222a1957430cf531a13c358f7ed54b7a4d96c26a Mon Sep 17 00:00:00 2001 From: saood06 Date: Thu, 24 Apr 2025 00:34:10 -0500 Subject: [PATCH 093/132] Update gguf-py constants (#298) * Update GGMLQuantizationType * Update LlamaFileType * Update GGML_QUANT_SIZES --- gguf-py/gguf/constants.py | 344 +++++++++++++++++++++++++------------- 1 file changed, 227 insertions(+), 117 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 22cde145..6e03f9a4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1171,47 +1171,86 @@ class PoolingType(IntEnum): class GGMLQuantizationType(IntEnum): - F32 = 0 - F16 = 1 - Q4_0 = 2 - Q4_1 = 3 - Q5_0 = 6 - Q5_1 = 7 - Q8_0 = 8 - Q8_1 = 9 - Q2_K = 10 - Q3_K = 11 - Q4_K = 12 - Q5_K = 13 - Q6_K = 14 - Q8_K = 15 - IQ2_XXS = 16 - IQ2_XS = 17 - IQ3_XXS = 18 - IQ1_S = 19 - IQ4_NL = 20 - IQ3_S = 21 - IQ2_S = 22 - IQ4_XS = 23 - I8 = 24 - I16 = 25 - I32 = 26 - I64 = 27 - F64 = 28 - IQ1_M = 29 - BF16 = 30 - Q4_0_4_4 = 31 - Q4_0_4_8 = 32 - Q4_0_8_8 = 33 - IQ1_BN = 34, - IQ2_BN = 35, - Q8_K64 = 36, - IQ2_K = 37, - IQ3_K = 38, - IQ4_K = 39, - IQ5_K = 40, - IQ6_K = 41, - IQ2_TN = 42, + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + IQ2_XXS = 16 + IQ2_XS = 17 + IQ3_XXS = 18 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = 22 + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + BF16 = 30 + Q4_0_4_4 = 31 + Q4_0_4_8 = 32 + Q4_0_8_8 = 33 + I2_S = 36 + Q8_0_X4 = 97 + Q8_1_X4 = 98 + Q8_2_X4 = 99 + Q6_0 = 133 + IQ1_BN = 134 + IQ2_BN = 135 + Q8_K64 = 136 + IQ2_K = 137 + IQ3_K = 138 + IQ4_K = 139 + IQ5_K = 140 + IQ6_K = 141 + IQ4_KS = 144 + IQ2_KS = 145 + IQ4_KSS = 146 + Q8_K16 = 147 + Q8_K32 = 148 + Q8_KR8 = 149 + Q8_K128 = 150 + Q8_KV = 151 + Q4_0_R8 = 202 + Q5_0_R4 = 206 + Q8_0_R8 = 208 + Q2_K_R4 = 210 + Q3_K_R4 = 211 + Q4_K_R4 = 212 + Q5_K_R4 = 213 + Q6_K_R4 = 214 + IQ2_XXS_R4= 216 + IQ2_XS_R4 = 217 + IQ3_XXS_R4= 218 + IQ1_S_R4 = 219 + IQ4_NL_R4 = 220 + IQ3_S_R4 = 221 + IQ2_S_R4 = 222 + IQ4_XS_R8 = 223 + IQ1_M_R4 = 229 + BF16_R16 = 230 + Q6_0_R4 = 233 + IQ2_BN_R4 = 335 + IQ2_K_R4 = 337 + IQ3_K_R4 = 338 + IQ4_K_R4 = 339 + IQ5_K_R4 = 340 + IQ4_KS_R4 = 344 + Q8_KV_R8 = 398 + Q8_K_R8 = 399 class ExpertGatingFuncType(IntEnum): @@ -1225,50 +1264,71 @@ class ExpertGatingFuncType(IntEnum): # from llama_ftype in llama.h # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. class LlamaFileType(IntEnum): - ALL_F32 = 0 - MOSTLY_F16 = 1 # except 1d tensors - MOSTLY_Q4_0 = 2 # except 1d tensors - MOSTLY_Q4_1 = 3 # except 1d tensors - # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 - # MOSTLY_Q4_2 = 5 # support has been removed - # MOSTLY_Q4_3 = 6 # support has been removed - MOSTLY_Q8_0 = 7 # except 1d tensors - MOSTLY_Q5_0 = 8 # except 1d tensors - MOSTLY_Q5_1 = 9 # except 1d tensors - MOSTLY_Q2_K = 10 # except 1d tensors - MOSTLY_Q3_K_S = 11 # except 1d tensors - MOSTLY_Q3_K_M = 12 # except 1d tensors - MOSTLY_Q3_K_L = 13 # except 1d tensors - MOSTLY_Q4_K_S = 14 # except 1d tensors - MOSTLY_Q4_K_M = 15 # except 1d tensors - MOSTLY_Q5_K_S = 16 # except 1d tensors - MOSTLY_Q5_K_M = 17 # except 1d tensors - MOSTLY_Q6_K = 18 # except 1d tensors - MOSTLY_IQ2_XXS = 19 # except 1d tensors - MOSTLY_IQ2_XS = 20 # except 1d tensors - MOSTLY_Q2_K_S = 21 # except 1d tensors - MOSTLY_IQ3_XS = 22 # except 1d tensors - MOSTLY_IQ3_XXS = 23 # except 1d tensors - MOSTLY_IQ1_S = 24 # except 1d tensors - MOSTLY_IQ4_NL = 25 # except 1d tensors - MOSTLY_IQ3_S = 26 # except 1d tensors - MOSTLY_IQ3_M = 27 # except 1d tensors - MOSTLY_IQ2_S = 28 # except 1d tensors - MOSTLY_IQ2_M = 29 # except 1d tensors - MOSTLY_IQ4_XS = 30 # except 1d tensors - MOSTLY_IQ1_M = 31 # except 1d tensors - MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q4_0_4_4 = 33 # except 1d tensors - MOSTLY_Q4_0_4_8 = 34 # except 1d tensors - MOSTLY_Q4_0_8_8 = 35 # except 1d tensors - MOSTLY_IQ1_BN = 36, # except 1d tensors - MOSTLY_IQ2_BN = 37, # except 1d tensors - MOSTLY_IQ2_K = 38, # except 1d tensors - MOSTLY_IQ3_K = 39, # except 1d tensors - MOSTLY_IQ4_K = 40, # except 1d tensors - MOSTLY_IQ5_K = 41, # except 1d tensors - MOSTLY_IQ6_K = 42, # except 1d tensors - MOSTLY_IQ2_TN = 43, # except 1d tensors + ALL_F32 = 0 + MOSTLY_F16 = 1 #except 1d tensors + MOSTLY_Q4_0 = 2 #except 1d tensors + MOSTLY_Q4_1 = 3 #except 1d tensors + MOSTLY_Q4_1_SOME_F16 = 4 #tok_embeddings.weight and output.weight are F16 + MOSTLY_Q8_0 = 7 #except 1d tensors + MOSTLY_Q5_0 = 8 #except 1d tensors + MOSTLY_Q5_1 = 9 #except 1d tensors + MOSTLY_Q2_K = 10 #except 1d tensors + MOSTLY_Q3_K = 11 #except 1d tensors + MOSTLY_Q4_K = 12 #except 1d tensors + MOSTLY_Q5_K = 13 #except 1d tensors + MOSTLY_Q6_K = 14 #except 1d tensors + MOSTLY_IQ2_XXS = 15 #except 1d tensors + MOSTLY_IQ2_XS = 16 #except 1d tensors + MOSTLY_IQ3_XXS = 17 #except 1d tensors + MOSTLY_IQ1_S = 18 #except 1d tensors + MOSTLY_IQ4_NL = 19 #except 1d tensors + MOSTLY_IQ3_S = 20 #except 1d tensors + MOSTLY_IQ2_S = 21 #except 1d tensors + MOSTLY_IQ4_XS = 22 #except 1d tensors + MOSTLY_IQ1_M = 23 #except 1d tensors + MOSTLY_BF16 = 24 #except 1d tensors + MOSTLY_Q4_0_4_4 = 25 #except 1d tensors + MOSTLY_Q4_0_4_8 = 26 #except 1d tensors + MOSTLY_Q4_0_8_8 = 27 #except 1d tensors + MOSTLY_Q6_0 = 127 #except 1d tensors + MOSTLY_IQ1_BN = 128 #except 1d tensors + MOSTLY_IQ2_BN = 129 #except 1d tensors + MOSTLY_IQ2_K = 130 #except 1d tensors + MOSTLY_IQ3_K = 131 #except 1d tensors + MOSTLY_IQ4_K = 132 #except 1d tensors + MOSTLY_IQ5_K = 133 #except 1d tensors + MOSTLY_IQ6_K = 134 #except 1d tensors + MOSTLY_IQ4_KS = 137 #except 1d tensors + MOSTLY_IQ2_KS = 138 #except 1d tensors + MOSTLY_IQ4_KSS = 139 #except 1d tensors + MOSTLY_Q8_KV = 140 #except 1d tensors + MOSTLY_Q4_0_R8 = 202 #except 1d tensors + MOSTLY_Q8_0_R8 = 207 #except 1d tensors + MOSTLY_Q5_0_R4 = 208 #except 1d tensors + MOSTLY_Q2_K_R4 = 210 #except 1d tensors + MOSTLY_Q3_K_R4 = 211 #except 1d tensors + MOSTLY_Q4_K_R4 = 212 #except 1d tensors + MOSTLY_Q5_K_R4 = 213 #except 1d tensors + MOSTLY_Q6_K_R4 = 214 #except 1d tensors + MOSTLY_IQ2_XXS_R4 = 215 #except 1d tensors + MOSTLY_IQ2_XS_R4 = 216 #except 1d tensors + MOSTLY_IQ3_XXS_R4 = 217 #except 1d tensors + MOSTLY_IQ1_S_R4 = 218 #except 1d tensors + MOSTLY_IQ4_NL_R4 = 219 #except 1d tensors + MOSTLY_IQ3_S_R4 = 220 #except 1d tensors + MOSTLY_IQ2_S_R4 = 221 #except 1d tensors + MOSTLY_IQ4_XS_R8 = 222 #except 1d tensors + MOSTLY_IQ1_M_R4 = 223 #except 1d tensors + MOSTLY_BF16_R16 = 224 #except 1d tensors + MOSTLY_Q6_0_R4 = 227 #except 1d tensors + MOSTLY_IQ2_BN_R4 = 329 #except 1d tensors + MOSTLY_IQ2_K_R4 = 330 #except 1d tensors + MOSTLY_IQ3_K_R4 = 331 #except 1d tensors + MOSTLY_IQ4_K_R4 = 332 #except 1d tensors + MOSTLY_IQ5_K_R4 = 333 #except 1d tensors + MOSTLY_IQ4_KS_R4 = 337 #except 1d tensors + MOSTLY_Q8_KV_R8 = 398 #except 1d tensors + MOSTLY_Q8_K_R8 = 399 #except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1313,39 +1373,89 @@ class GGUFValueType(IntEnum): # Items here are (block size, type size) QK_K = 256 + +#Values generated programatically GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { - GGMLQuantizationType.F32: (1, 4), - GGMLQuantizationType.F16: (1, 2), - GGMLQuantizationType.Q4_0: (32, 2 + 16), - GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), - GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), - GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), - GGMLQuantizationType.Q8_0: (32, 2 + 32), - GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), - GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), - GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), - GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), - GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), - GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), - GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), - GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), - GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), - GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), - GGMLQuantizationType.IQ4_NL: (32, 2 + 16), - GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), - GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), - GGMLQuantizationType.I8: (1, 1), - GGMLQuantizationType.I16: (1, 2), - GGMLQuantizationType.I32: (1, 4), - GGMLQuantizationType.I64: (1, 8), - GGMLQuantizationType.F64: (1, 8), - GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), - GGMLQuantizationType.BF16: (1, 2), - GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), - GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), - GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), + GGMLQuantizationType.F32 : ( 1, 4), + GGMLQuantizationType.F16 : ( 1, 2), + GGMLQuantizationType.Q4_0 : ( 32, 18), + GGMLQuantizationType.Q4_1 : ( 32, 20), + GGMLQuantizationType.Q5_0 : ( 32, 22), + GGMLQuantizationType.Q5_1 : ( 32, 24), + GGMLQuantizationType.Q8_0 : ( 32, 34), + GGMLQuantizationType.Q8_1 : ( 32, 36), + GGMLQuantizationType.Q2_K : ( 256, 84), + GGMLQuantizationType.Q3_K : ( 256, 110), + GGMLQuantizationType.Q4_K : ( 256, 144), + GGMLQuantizationType.Q5_K : ( 256, 176), + GGMLQuantizationType.Q6_K : ( 256, 210), + GGMLQuantizationType.Q8_K : ( 256, 292), + GGMLQuantizationType.IQ2_XXS : ( 256, 66), + GGMLQuantizationType.IQ2_XS : ( 256, 74), + GGMLQuantizationType.IQ3_XXS : ( 256, 98), + GGMLQuantizationType.IQ1_S : ( 256, 50), + GGMLQuantizationType.IQ4_NL : ( 32, 18), + GGMLQuantizationType.IQ3_S : ( 256, 110), + GGMLQuantizationType.IQ2_S : ( 256, 82), + GGMLQuantizationType.IQ4_XS : ( 256, 136), + GGMLQuantizationType.I8 : ( 1, 1), + GGMLQuantizationType.I16 : ( 1, 2), + GGMLQuantizationType.I32 : ( 1, 4), + GGMLQuantizationType.I64 : ( 1, 8), + GGMLQuantizationType.F64 : ( 1, 8), + GGMLQuantizationType.IQ1_M : ( 256, 56), + GGMLQuantizationType.BF16 : ( 1, 2), + GGMLQuantizationType.Q4_0_4_4 : ( 32, 18), + GGMLQuantizationType.Q4_0_4_8 : ( 32, 18), + GGMLQuantizationType.Q4_0_8_8 : ( 32, 18), + GGMLQuantizationType.I2_S : ( 1, 1), + GGMLQuantizationType.Q8_0_X4 : ( 32, 34), + GGMLQuantizationType.Q8_1_X4 : ( 32, 36), + GGMLQuantizationType.Q8_2_X4 : ( 32, 36), + GGMLQuantizationType.Q6_0 : ( 32, 26), + GGMLQuantizationType.IQ1_BN : ( 64, 13), + GGMLQuantizationType.IQ2_BN : ( 64, 16), + GGMLQuantizationType.Q8_K64 : ( 64, 68), + GGMLQuantizationType.IQ2_K : ( 256, 76), + GGMLQuantizationType.IQ3_K : ( 256, 110), + GGMLQuantizationType.IQ4_K : ( 256, 144), + GGMLQuantizationType.IQ5_K : ( 256, 176), + GGMLQuantizationType.IQ6_K : ( 256, 212), + GGMLQuantizationType.IQ4_KS : ( 256, 136), + GGMLQuantizationType.IQ2_KS : ( 256, 70), + GGMLQuantizationType.IQ4_KSS : ( 256, 128), + GGMLQuantizationType.Q8_K16 : ( 64, 64), + GGMLQuantizationType.Q8_K32 : ( 256, 292), + GGMLQuantizationType.Q8_KR8 : ( 256, 292), + GGMLQuantizationType.Q8_K128 : ( 128, 140), + GGMLQuantizationType.Q8_KV : ( 32, 32), + GGMLQuantizationType.Q4_0_R8 : ( 32, 18), + GGMLQuantizationType.Q5_0_R4 : ( 32, 22), + GGMLQuantizationType.Q8_0_R8 : ( 32, 34), + GGMLQuantizationType.Q2_K_R4 : ( 256, 84), + GGMLQuantizationType.Q3_K_R4 : ( 256, 110), + GGMLQuantizationType.Q4_K_R4 : ( 256, 144), + GGMLQuantizationType.Q5_K_R4 : ( 256, 176), + GGMLQuantizationType.Q6_K_R4 : ( 256, 210), + GGMLQuantizationType.IQ2_XXS_R4 : ( 256, 66), + GGMLQuantizationType.IQ2_XS_R4 : ( 256, 74), + GGMLQuantizationType.IQ3_XXS_R4 : ( 256, 98), + GGMLQuantizationType.IQ1_S_R4 : ( 32, 6), + GGMLQuantizationType.IQ4_NL_R4 : ( 32, 18), + GGMLQuantizationType.IQ3_S_R4 : ( 256, 110), + GGMLQuantizationType.IQ2_S_R4 : ( 256, 82), + GGMLQuantizationType.IQ4_XS_R8 : ( 256, 136), + GGMLQuantizationType.IQ1_M_R4 : ( 32, 7), + GGMLQuantizationType.BF16_R16 : ( 1, 2), + GGMLQuantizationType.Q6_0_R4 : ( 32, 26), + GGMLQuantizationType.IQ2_BN_R4 : ( 64, 16), + GGMLQuantizationType.IQ2_K_R4 : ( 256, 76), + GGMLQuantizationType.IQ3_K_R4 : ( 256, 110), + GGMLQuantizationType.IQ4_K_R4 : ( 256, 144), + GGMLQuantizationType.IQ5_K_R4 : ( 256, 176), + GGMLQuantizationType.IQ4_KS_R4 : ( 256, 136), + GGMLQuantizationType.Q8_KV_R8 : ( 32, 32), + GGMLQuantizationType.Q8_K_R8 : ( 256, 258), } From c9eec1729fe95a5fcfd4ce47df440c2445abb17e Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 24 Apr 2025 17:37:12 +0200 Subject: [PATCH 094/132] cuda: use switch in constexpr funcs (#343) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/mmq.cuh | 83 +++++++++++++++++++------------------ ggml/src/ggml-cuda/mmvq.cu | 84 ++++++++++++++++++++------------------ 2 files changed, 87 insertions(+), 80 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 88e023a1..753848d7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -160,26 +160,27 @@ static constexpr __device__ int get_mmq_y_device() { #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { - return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : - type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : - type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : - tile_x_sizes{0, 0, 0}; + switch (type) { + case GGML_TYPE_Q4_1 : return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1 : return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q6_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q8_0 : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q2_K : return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K : return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K : return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K : return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K : return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S : return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; + default : return tile_x_sizes{0, 0, 0}; + } } #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) @@ -195,26 +196,28 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : - 0; + switch (type) { + case GGML_TYPE_Q4_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1 : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1 : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q8_0 : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q2_K : return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K : return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K : return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S : return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; + default : return 0; + } } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 69ccef1d..3d991b4d 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -12,49 +12,53 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : - type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : - type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : - type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : - type == GGML_TYPE_Q6_0 ? vec_dot_q6_0_q8_1 : - type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : - type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : - type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : - type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : - type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : - type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : - type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : - type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : - type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : - type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : - type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : - type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : - type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : - type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : - nullptr; + switch (type) { + case GGML_TYPE_Q4_0 : return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1 : return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0 : return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1 : return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q6_0 : return vec_dot_q6_0_q8_1; + case GGML_TYPE_Q8_0 : return vec_dot_q8_0_q8_1; + case GGML_TYPE_Q2_K : return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K : return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K : return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K : return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K : return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS : return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S : return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S : return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M : return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S : return vec_dot_iq3_s_q8_1; + default : return nullptr; + } } static constexpr __device__ int get_vdr_mmvq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : - type == GGML_TYPE_Q6_0 ? VDR_Q6_0_Q8_1_MMVQ : - type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : - 1; + switch (type) { + case GGML_TYPE_Q4_0 : return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1 : return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0 : return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1 : return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q6_0 : return VDR_Q6_0_Q8_1_MMVQ; + case GGML_TYPE_Q8_0 : return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_Q2_K : return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K : return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K : return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K : return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K : return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS : return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS : return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S : return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S : return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL : return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS : return VDR_IQ4_XS_Q8_1_MMVQ; + default : return 1; + } } template From f176122a3d50c781414458b498b9426086a91647 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 25 Apr 2025 09:21:03 +0200 Subject: [PATCH 095/132] Fix LLaMA-4 attention (#342) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d0fd9c48..c870b09e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9974,7 +9974,12 @@ struct llm_build_context { } // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + //bool is_swa = hparams.n_swa > 0 && h_params.n_swa_pattern > 0 ? + ggml_tensor * KQ_mask = build_inp_KQ_mask(); + ggml_tensor * KQ_mask_swa = nullptr; + if (hparams.n_swa > 0 && hparams.n_swa_pattern > 0) { + KQ_mask_swa = build_inp_KQ_mask_swa(); + } //const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : 1.f; @@ -9982,6 +9987,8 @@ struct llm_build_context { struct ggml_tensor * inpSA = inpL; bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true; + auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ? + KQ_mask_swa : KQ_mask; // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -10046,7 +10053,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } if (il == n_layer - 1) { From 25d1a0dca87e8d0237471e6d426a971e1d5289a2 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 25 Apr 2025 11:01:08 +0200 Subject: [PATCH 096/132] Fix FA on ARM (#346) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 29ae9d53..45d804a4 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -16142,7 +16142,13 @@ struct FlashQKV { std::memcpy(S, fms.S, nq1*sizeof(float)); auto R = qkv_cache; for (int j = 0; j < nq1; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i)); + } +#else std::memcpy(qkv, R, D*sizeof(float)); +#endif qkv += stride_qkv; R += D; } @@ -16162,7 +16168,13 @@ struct FlashQKV { std::memcpy(S, fms.S, q_step*sizeof(float)); auto R = qkv_cache; for (int j = 0; j < q_step; ++j) { +#ifdef __aarch64__ + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(qkv + F16::block_size*i, F16::load(R + F16::block_size*i)); + } +#else std::memcpy(qkv, R, D*sizeof(float)); +#endif qkv += stride_qkv; R += D; } From c817160d0395ddbc2ad79d3c0ce52a4da63ac81d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 25 Apr 2025 13:24:18 +0200 Subject: [PATCH 097/132] Add ability to manually set arch flags (#347) Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 0aa08adf..17a9e2eb 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -1086,7 +1086,7 @@ if (NOT MSVC) endif() endif() -set(ARCH_FLAGS "") +set(ARCH_FLAGS ${GGML_ARCH_FLAGS}) if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR From 770892086c15471a397a6a1a196986de906cdc91 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 25 Apr 2025 19:48:08 +0200 Subject: [PATCH 098/132] Fix q4_1 and q5_1 on Arm (#348) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 83a48cb6..ad9393cc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -741,7 +741,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, #if GGML_USE_IQK_MULMAT +#if defined __AVX2__ .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_1, #endif @@ -809,7 +813,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, #if GGML_USE_IQK_MULMAT +#ifdef __AVX2__ .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_1, #endif From 715fc552ad2ea5fad38e7ff856bf84fdb71b692e Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 26 Apr 2025 08:13:25 +0200 Subject: [PATCH 099/132] Add support for Cohere2 (#341) * Add support for Cohere2 * Fixe IQ4_NL on AVX2 * Command-A needs fp32 precision for K*Q --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 2 +- ggml/src/iqk/iqk_mul_mat.cpp | 49 ++++++++- src/llama.cpp | 203 ++++++++++++++++++++++++++++++++++- 3 files changed, 247 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ad9393cc..88013f74 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1289,7 +1289,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, #if GGML_USE_IQK_MULMAT -#if defined __AVX2__ +#if defined HAVE_FANCY_SIMD .vec_dot_type = GGML_TYPE_Q8_2_X4, #else .vec_dot_type = GGML_TYPE_Q8_0_X4, diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 45d804a4..e7ab2e5b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1750,6 +1750,15 @@ __m256i inline load_iq4nl_values_256() { return MM256_SET_M128I(val128, val128); } +__m128i inline load_iq4k_values_128() { + return _mm_loadu_si128((const __m128i *)iq4k_values); +} + +__m256i inline load_iq4k_values_256() { + auto val128 = load_iq4k_values_128(); + return MM256_SET_M128I(val128, val128); +} + #ifdef HAVE_FANCY_SIMD //====================================== Zen4 ================================================== @@ -8519,7 +8528,11 @@ struct Q4_0_1_Dequantizer { struct IQ4_NL_Dequantizer { Dequantizer4bit b4; +#ifdef HAVE_FANCY_SIMD const __m256i values = load_iq4nl_values_256(); +#else + const __m256i values = load_iq4k_values_256(); +#endif inline __m256i dequant(const block_iq4_nl * x) const { return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); } @@ -8630,11 +8643,19 @@ struct Q4_0_1_Unpacker final : public Q_Unpacker using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_0; } }; +#ifdef HAVE_FANCY_SIMD struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> { IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ82; inline static int block_size() { return QK4_NL; } }; +#else +struct IQ4_NL_Unpacker final : public Q_Unpacker { + IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK4_NL; } +}; +#endif struct Q5_0_Unpacker final : public Q_Unpacker { Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -9155,9 +9176,29 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[6] = mul_mat_qX_1_q8_2_T; m.funcs[7] = mul_mat_qX_1_q8_2_T; } + else if constexpr (std::is_same_v) { +#ifdef HAVE_FANCY_SIMD + m.funcs[0] = mul_mat_qX_1_q8_2_T; + m.funcs[1] = mul_mat_qX_1_q8_2_T; + m.funcs[2] = mul_mat_qX_1_q8_2_T; + m.funcs[3] = mul_mat_qX_1_q8_2_T; + m.funcs[4] = mul_mat_qX_1_q8_2_T; + m.funcs[5] = mul_mat_qX_1_q8_2_T; + m.funcs[6] = mul_mat_qX_1_q8_2_T; + m.funcs[7] = mul_mat_qX_1_q8_2_T; +#else + m.funcs[0] = mul_mat_qX_0_q8_0_T; + m.funcs[1] = mul_mat_qX_0_q8_0_T; + m.funcs[2] = mul_mat_qX_0_q8_0_T; + m.funcs[3] = mul_mat_qX_0_q8_0_T; + m.funcs[4] = mul_mat_qX_0_q8_0_T; + m.funcs[5] = mul_mat_qX_0_q8_0_T; + m.funcs[6] = mul_mat_qX_0_q8_0_T; + m.funcs[7] = mul_mat_qX_0_q8_0_T; +#endif + } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_1_q8_2_T; m.funcs[1] = mul_mat_qX_1_q8_2_T; m.funcs[2] = mul_mat_qX_1_q8_2_T; @@ -9476,7 +9517,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ4_NL: assert (ne00 % QK4_NL == 0); MulMat::set_functions(mm); +#ifdef HAVE_FANCY_SIMD expected_typeB = GGML_TYPE_Q8_2_X4; +#else + expected_typeB = GGML_TYPE_Q8_0_X4; +#endif break; case GGML_TYPE_IQ4_NL_R4: assert (ne00 % QK4_NL == 0); diff --git a/src/llama.cpp b/src/llama.cpp index c870b09e..26ddeb2e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -229,6 +229,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_GRANITE = 46, LLM_ARCH_GRANITE_MOE, + LLM_ARCH_COHERE2, LLM_ARCH_UNKNOWN, }; @@ -279,6 +280,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_COHERE2, "cohere2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1456,7 +1458,21 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, - + { + LLM_ARCH_COHERE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2539,6 +2555,7 @@ struct llama_hparams { if (this->n_layer != other.n_layer) return true; if (this->n_rot != other.n_rot) return true; if (this->n_swa != other.n_swa) return true; + if (this->n_swa_pattern != other.n_swa_pattern) return false; if (this->n_embd_head_k != other.n_embd_head_k) return true; if (this->n_embd_head_v != other.n_embd_head_v) return true; if (this->n_expert != other.n_expert) return true; @@ -5797,6 +5814,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_COHERE2: + { + hparams.n_swa_pattern = 4; + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -6406,6 +6434,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); @@ -8397,6 +8426,34 @@ static bool llm_load_tensors( layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); } } break; + case LLM_ARCH_COHERE2: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + llama_model_loader::TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + } + break; default: throw std::runtime_error("unknown architecture"); } @@ -9340,7 +9397,7 @@ static struct ggml_tensor * llm_build_kqv( // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel. if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || - (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8)) { + (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -9364,7 +9421,8 @@ static struct ggml_tensor * llm_build_kqv( //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || + model.arch == LLM_ARCH_COHERE2) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -9423,7 +9481,8 @@ static struct ggml_tensor * llm_build_kqv( auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02); auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12); auto kq_i = ggml_mul_mat(ctx, k_i, q_i); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || + model.arch == LLM_ARCH_COHERE2) { ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32); } if (model.arch == LLM_ARCH_GROK) { @@ -15013,6 +15072,137 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_cohere2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const float f_logit_scale = hparams.f_logit_scale; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + // cohere2 requires different mask for layers using sliding window (SWA) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); + + // sliding window switch pattern + const int32_t sliding_window_pattern = 4; + + for (int il = 0; il < n_layer; ++il) { + // three layers sliding window attention (window size 4096) and ROPE + // fourth layer uses global attention without positional embeddings + const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1); + struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + struct ggml_tensor * ffn_inp = cur; + + // self-attention + { + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + if (is_sliding) { + Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, + beta_fast, beta_slow); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); + } else { + // For non-sliding layers, just reshape without applying RoPE + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + } + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, + KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids); + } + + struct ggml_tensor * attn_out = cur; + + // feed-forward network + { + cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, + NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, + cb, il); + cb(cur, "ffn_out", il); + } + + // add together residual + FFN + self-attention + cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, attn_out); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + if (f_logit_scale) { + cur = ggml_scale(ctx0, cur, f_logit_scale); + } + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_t5_encoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -15813,6 +16003,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_bitnet_25(); } break; + case LLM_ARCH_COHERE2: + { + result = llm.build_cohere2(); + } break; case LLM_ARCH_T5: { if (lctx.is_encoding) { @@ -19486,6 +19680,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_COHERE2: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From 9e846f0eb196ed543cb29753bfd6a21a936a5138 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 26 Apr 2025 09:19:43 +0200 Subject: [PATCH 100/132] Fix division by zero bug (#349) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 2 +- ggml/src/iqk/iqk_flash_attn.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 88013f74..f3cfd9a0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21790,7 +21790,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const struct ggml_tensor * k = node->src[1]; if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { if (k->ne[2] > 1) { - int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); + int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks))); int nstep_k = k->ne[2]*k->ne[1]/nk; size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); size_t size = nstep_k*result_size; diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 639b79af..0de68b94 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -154,7 +154,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float } if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { - int nk = 32 * (nek2*nek1/(32*nth)); + int nk = std::max(1, 32 * (nek2*nek1/(32*nth))); int nkk = (nek1 + nk - 1)/nk; int nstep_k = nek2*nkk; auto result_size = (Dv + 16)*rk2*sizeof(float); From baeefb4731fb24cdace168f6dbc74516d470efc0 Mon Sep 17 00:00:00 2001 From: ubergarm Date: Sat, 26 Apr 2025 11:34:04 -0400 Subject: [PATCH 101/132] Add GLM-4-0414 Model Support (#344) * Add GLM-4-0414 model support Based on zRzRzRzRzRzRzR's PR on mainline llama.cpp. Still some issues where it doesn't work: * offloading >=60 layers to GPU * no flash attention * Remove seemingly unused llm_tensor enums Both of these seem unused and LLM_TENSOR_ATTN_POST_NORM already existed which seems pretty similar? Don't think they were used in the python code either... So removed these as possibly just cruft: * LLM_TENSOR_POST_ATTN_NORM * LLM_TENSOR_POST_MLP_NORM * Set flash attention precision to f32 on GLM4 arch * Set non flash attention precision to f32 on GLM4 * Remove reshape_3d() for Vcur in build_glm4() This fixes the non-flash-attention inferencing on both CPU and CUDA. --- src/llama.cpp | 222 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 218 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 26ddeb2e..e5988fd4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -222,6 +222,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, + LLM_ARCH_GLM4, LLM_ARCH_BITNET, LLM_ARCH_BITNET_25, LLM_ARCH_T5, @@ -273,6 +274,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_BITNET_25, "bitnet-25" }, { LLM_ARCH_T5, "t5" }, @@ -1309,6 +1311,25 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_GLM4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -2423,6 +2444,7 @@ enum e_model { MODEL_16B, MODEL_20B, MODEL_30B, + MODEL_32B, MODEL_34B, MODEL_35B, MODEL_40B, @@ -5025,6 +5047,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; case MODEL_30B: return "30B"; + case MODEL_32B: return "32B"; case MODEL_34B: return "34B"; case MODEL_35B: return "35B"; case MODEL_40B: return "40B"; @@ -5730,6 +5753,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GLM4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_9B; break; + case 61: model.type = e_model::MODEL_32B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6046,6 +6078,7 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; vocab.tokenizer_clean_spaces = false; } else if ( + tokenizer_pre == "glm4" || tokenizer_pre == "chatglm-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; vocab.special_bos_id = -1; @@ -8454,6 +8487,48 @@ static bool llm_load_tensors( } } break; + case LLM_ARCH_GLM4: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_post_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -9397,7 +9472,7 @@ static struct ggml_tensor * llm_build_kqv( // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. // Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel. if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || - (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2) { + (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } //ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -9422,7 +9497,7 @@ static struct ggml_tensor * llm_build_kqv( //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || - model.arch == LLM_ARCH_COHERE2) { + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -9482,7 +9557,7 @@ static struct ggml_tensor * llm_build_kqv( auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12); auto kq_i = ggml_mul_mat(ctx, k_i, q_i); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || - model.arch == LLM_ARCH_COHERE2) { + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4) { ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32); } if (model.arch == LLM_ARCH_GROK) { @@ -10033,7 +10108,7 @@ struct llm_build_context { } // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - //bool is_swa = hparams.n_swa > 0 && h_params.n_swa_pattern > 0 ? + //bool is_swa = hparams.n_swa > 0 && h_params.n_swa_pattern > 0 ? ggml_tensor * KQ_mask = build_inp_KQ_mask(); ggml_tensor * KQ_mask_swa = nullptr; if (hparams.n_swa > 0 && hparams.n_swa_pattern > 0) { @@ -15745,6 +15820,140 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_glm4() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + } else { + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = llm_build_norm(ctx0, inpL, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // Output projection + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -15995,6 +16204,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_chatglm(); } break; + case LLM_ARCH_GLM4: + { + result = llm.build_glm4(); + } break; case LLM_ARCH_BITNET: { result = llm.build_bitnet(); @@ -19678,6 +19891,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_CHATGLM: + case LLM_ARCH_GLM4: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_COHERE2: From cda24b58cbef34154651d0083910fed860a506c1 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 29 Apr 2025 07:19:43 +0200 Subject: [PATCH 102/132] CPU FA improvements (#351) * FA: provide work buffer for K repacking * Add header to avoid comp0iler warnings * WIP * WIP * WIP * WIP * Slightly better * WIP (Zen4) * WIP * Try to improve for unusual number of heads/number of threads * Use mul_mat_qX_0_q8_2_Tx for q6_0 in FA * Use mul_mat_qX_0_q8_2_Tx for q4_0 in FA * Use Sum4q4 for q4_0 * WIP * WIP * Much better FA TG with q8_0 KV cache Just repack it even for TG. But do the repacking for k_step rows, not the whole K tensor. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml.c | 29 +- ggml/src/iqk/iqk_flash_attn.cpp | 147 +++++-- ggml/src/iqk/iqk_flash_impl.h | 4 + ggml/src/iqk/iqk_mul_mat.cpp | 714 ++++++++++++++++++++++++++++---- 4 files changed, 766 insertions(+), 128 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f3cfd9a0..4cd18a28 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -21786,15 +21786,36 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread #if GGML_USE_IQK_MULMAT + size_t qsize = 0; const struct ggml_tensor * q = node->src[0]; const struct ggml_tensor * k = node->src[1]; + if (k->type == GGML_TYPE_Q8_0) { + qsize = ggml_nrows(k)*ggml_row_size(k->type, k->ne[0]); + } if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) { if (k->ne[2] > 1) { - int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks))); + int gcd = simple_gcd(k->ne[2], n_tasks); + int nth_k = n_tasks/gcd; + int nek2_k = k->ne[2]/gcd; + int nchunk = nek2_k*k->ne[1]/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (k->ne[1]*k->ne[2]/(32*n_tasks)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)); int nstep_k = k->ne[2]*k->ne[1]/nk; size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float); size_t size = nstep_k*result_size; - cur = MAX(cur, size); + cur = MAX(cur, size+qsize); } else { int nstep_k = k->ne[1]/32; int gcd_k = simple_gcd(nstep_k, n_tasks); @@ -21808,9 +21829,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]); size += q->ne[2]*row_size; } - cur = MAX(cur, size); + cur = MAX(cur, size+qsize); } } + } else { + cur = MAX(cur, qsize); } #endif } break; diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index 0de68b94..fd0d5dd0 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -25,6 +25,24 @@ inline uint32_t simple_gcd(uint32_t a, uint32_t b) { } return a; } +inline void accumulate_qkv(int Dv, float& M, float& S, float Mj, float Sj, float * Racc, const float * R) { + if (Mj == -INFINITY) return; + if (Mj > M) { + if (M == -INFINITY) { + std::memcpy(Racc, R, Dv*sizeof(float)); + S = Sj; + } else { + float c = exp(M - Mj); + S = c*S + Sj; + for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; + } + M = Mj; + } else { + float c = exp(Mj - M); + S += c*Sj; + for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; + } +} } // TODO: get the ggml_type enum here without polution @@ -34,7 +52,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int nek3, int nek2, long nbk3, long nbk2, int nev3, int nev2, long nbv3, long nbv2, int ne2, int ne1, long nb1, - int int_type_k, // type of k + int int_type_k_in, // type of k int int_type_v, // type of v int Dk, // K head size int Dv, // V head size @@ -51,7 +69,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float float scale, // scale applied before softmax float softcap, // if > 0, a "soft-cap" operation is applied before softmax float * qkv, // v*softmax(scale*(k*q)) - [[maybe_unused]] void * work_buffer, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, + [[maybe_unused]] void * work_buffer_in, [[maybe_unused]] barrier_t barrier, [[maybe_unused]] void * barrier_data, int ith, int nth) { if (type_q != 0 || type_mask != 1 || max_bias > 0) return false; @@ -61,6 +79,29 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int rk3 = neq3/nek3; int rv3 = neq3/nev3; + int int_type_k = int_type_k_in; + auto work_buffer = work_buffer_in; + if (neq1 >= 8 || rk2 >= 8) { + uint64_t row_size = 0; + work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + if (int_type_k != int_type_k_in) { + stride_k = row_size; + nbk2 = stride_k*nek1; + nbk3 = nbk2*nek2; + k = work_buffer_in; + barrier(barrier_data); + } + } + //uint64_t row_size = 0; + //auto work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); + //if (int_type_k != int_type_k_in) { + // stride_k = row_size; + // nbk2 = stride_k*nek1; + // nbk3 = nbk2*nek2; + // k = work_buffer_in; + // barrier(barrier_data); + //} + // Getting confused all the time about where to load data from and store the results to // (especially when combining the results from the threads). // So, for now, making it work just for MLA (nek2 = 1). @@ -128,22 +169,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float auto Mj = R + Dv*nq_this_j; auto Sj = Mj + nq_this_j; R += jj*Dv; - if (Mj[jj] == -INFINITY) continue; - if (Mj[jj] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[jj]; - } else { - float c = exp(M - Mj[jj]); - S = c*S + Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[jj]; - } else { - float c = exp(Mj[jj] - M); - S += c*Sj[jj]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } + accumulate_qkv(Dv, M, S, Mj[jj], Sj[jj], Racc, R); } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; @@ -154,10 +180,72 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float } if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) { - int nk = std::max(1, 32 * (nek2*nek1/(32*nth))); + auto result_size = (Dv + 16)*rk2*sizeof(float); + int gcd = simple_gcd(nek2, nth); + if (false && gcd > 1) { + int nth_g = nth/gcd; + int ith_g = ith%nth_g; + int nek1_32 = nek1/32; + int nek1_pt = (nek1_32 + nth_g - 1)/nth_g; + int ith_mid = nth_g; + if (nek1_pt*nth_g > nek1_32) { + ith_mid = nek1_32 - nth_g*(nek1_pt - 1); + } + nek1_pt *= 32; + int nek1_mid = ith_mid*nek1_pt; + int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32; + for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) { + int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread; + auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size); + auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2); + auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2; + auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2; + auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here + if (!iqk_flash_attn_impl(int_type_k, int_type_v, + Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv, + this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, + scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false; + } + + barrier(barrier_data); + + for (int iq2 = ith; iq2 < neq2; iq2 += nth) { + int ik02 = iq2/rk2; + int il = iq2 - ik02*rk2; + auto Racc = qkv + iq2*nb1/sizeof(float); + float M = -INFINITY, S = 0; + for (int ig = 0; ig < nth_g; ++ig) { + int istep_k = ik02*nth_g + ig; + auto this_result = (float *)((char *)work_buffer + istep_k*result_size); + const float * R = this_result + il*Dv; + const float * Mj = this_result + Dv*rk2; + const float * Sj = Mj + rk2; + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); + } + float norm = S > 0 ? 1/S : 1; + for (int i = 0; i < Dv; ++i) Racc[i] *= norm; + } + return true; + } + int nth_k = nth/gcd; + int nek2_k = nek2/gcd; + int nchunk = nek2_k*nek1/32; + int npt = (nchunk + nth_k - 1)/nth_k; + int nk; + if (npt*nth_k == nchunk) { + nk = 32 * (nek2*nek1/(32*nth)); + } else { + //int nm = std::max(1, npt/8); + int nm = 1; + while (true) { + if (nm*4 >= npt) break; + nm *= 2; + } + nk = 32*nm; + } + //int nk = 32 * (nek2*nek1/(32*nth)); int nkk = (nek1 + nk - 1)/nk; int nstep_k = nek2*nkk; - auto result_size = (Dv + 16)*rk2*sizeof(float); //if (ith == 0) printf("rk2 = %d, nek1 = %d, nek2 = %d, nk = %d, nkk = %d, nstep_k = %d\n", (int)rk2, (int)nek1, (int)nek2, nk, nkk, nstep_k); for (int istep_k = ith; istep_k < nstep_k; istep_k += nth) { int ik02 = istep_k/nkk; @@ -183,7 +271,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int ik02 = iq2/rk2; int il = iq2 - ik02*rk2; auto Racc = qkv + iq2*nb1/sizeof(float); - std::memset(Racc, 0, Dv*sizeof(float)); + //std::memset(Racc, 0, Dv*sizeof(float)); float M = -INFINITY, S = 0; for (int ikk = 0; ikk < nkk; ++ikk) { int istep_k = ik02*nkk + ikk; @@ -191,22 +279,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float const float * R = this_result + il*Dv; const float * Mj = this_result + Dv*rk2; const float * Sj = Mj + rk2; - if (Mj[il] == -INFINITY) continue; - if (Mj[il] > M) { - if (M == -INFINITY) { - std::memcpy(Racc, R, Dv*sizeof(float)); - S = Sj[il]; - } else { - float c = exp(M - Mj[il]); - S = c*S + Sj[il]; - for (int i = 0; i < Dv; ++i) Racc[i] = c*Racc[i] + R[i]; - } - M = Mj[il]; - } else { - float c = exp(Mj[il] - M); - S += c*Sj[il]; - for (int i = 0; i < Dv; ++i) Racc[i] += c*R[i]; - } + accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R); } float norm = S > 0 ? 1/S : 1; for (int i = 0; i < Dv; ++i) Racc[i] *= norm; diff --git a/ggml/src/iqk/iqk_flash_impl.h b/ggml/src/iqk/iqk_flash_impl.h index 68802927..6f62e56b 100644 --- a/ggml/src/iqk/iqk_flash_impl.h +++ b/ggml/src/iqk/iqk_flash_impl.h @@ -6,6 +6,8 @@ #pragma once +#include + bool iqk_flash_attn_impl(int type_k, // type of k int type_v, // type of v int Dk, // K head size @@ -27,3 +29,5 @@ bool iqk_flash_attn_impl(int type_k, // type of k float * M, float * S); +void * iqk_repack_k(int type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3, + const void * k, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index e7ab2e5b..5f916584 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -19,6 +19,7 @@ #include "ggml-quants.h" #include "iqk_mul_mat.h" #include "iqk_quantize.h" +#include "iqk_flash_impl.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" @@ -6639,6 +6640,84 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } } +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m512i qx[4]; + __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {}; + float dy[nrc_y]; + int32_t sy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -64*iptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[8]; + float dx[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int kx = 0; kx < 8; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) { + qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)), + _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1); + } + auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]); + qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128)); + qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + if constexpr (nrc_y <= 4) { + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } else { + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + } + auto scales_x = _mm256_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + if constexpr (nrc_y <= 4) { + auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } else { + acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[iy] = _mm512_setzero_si512(); + } + } + } +} +#endif + template static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -8208,6 +8287,22 @@ template struct return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 } } + inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); } +}; + +template struct Sum4q4 { + inline __m256i compute(const __m256i * qx, const Q8 * y) const { + const Q8x4 * y4 = (const Q8x4 *)y; + auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0 + auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1 + auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2 + auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3 + auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1 + auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3 + auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 + return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123); + } + inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); } }; struct ScaleHelperQ8_0 { @@ -8362,6 +8457,7 @@ struct MinusType0 { inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } inline float compute(float d, int) const { return d; } inline float result(__m256 acc, int) const { return hsum_float_8(acc); } + inline __m256 vresult(__m256 acc, int) const { return acc; } }; template struct MinusType1 { @@ -8381,6 +8477,9 @@ template struct MinusType1 { const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); return hsum_float_4(_mm_add_ps(sum, accm[iy])); } + inline __m256 vresult(__m256 acc, int iy) const { + return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0)); + } }; template struct AccumT { @@ -8408,7 +8507,7 @@ template struct AccumT { for (int iy = 0; iy < nrc_y; ++iy) { auto s12 = scales.prepare1(other_scales, y[iy] + i); auto d = accm.compute(s12, iy); - const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); } } @@ -8417,6 +8516,36 @@ template struct AccumT { info.store(ix, iy, accm.result(acc[iy], iy)); } } + template + inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) { + auto qx = unp.quants(); + __m256 dall[nrc_y]; + for (int i = 0; i < nb/4; ++i) { + auto other_scales = unp.set_block_4(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); + dall[iy] = accm.compute(s12, iy); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto pall = sum.compute(qx, y[iy] + 4*i); + acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); + } + } + if (!is_multiple_of_4) { + for (int i = 4*(nb/4); i < nb; ++i) { + auto other_scales = unp.set_block(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare1(other_scales, y[iy] + i); + auto d = accm.compute(s12, iy); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + result[iy] = accm.vresult(acc[iy], iy); + } + } }; template @@ -8425,10 +8554,7 @@ using AccumType0 = AccumT; template using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; -using Sum4Type0 = Sum4; -using Sum4Type1 = Sum4; using Sum4TypeQ80 = Sum4; -//using Sum4TypeQ81 = Sum4; using Sum4TypeQ82 = Sum4; template @@ -8443,6 +8569,19 @@ void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& in } } +template +void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { + GGML_ASSERT(nrc_x%2 == 0); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + Scales scales; + for (int ix = 0; ix < nrc_x; ix += 2) { + unp.set_row(ix); + AccumType accum; + accum.compute(nb, unp, scales, sum4, y, info, ix); + } +} + template void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); @@ -8459,6 +8598,63 @@ void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info } } +inline __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31)); + //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} + +template +void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_0 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + + template void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%Unpacker::block_size() == 0); @@ -8491,6 +8687,52 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info } } +template +void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_2 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { @@ -8640,7 +8882,8 @@ struct Q4_0_Unpacker final : public Q_Unpacker, Q4_0_1_Dequantizer> { Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; + //using Sum4T = Sum4TypeQ82; + using Sum4T = Sum4q4; inline static int block_size() { return QK4_0; } }; #ifdef HAVE_FANCY_SIMD @@ -15168,6 +15411,13 @@ struct F16 { auto v256 = _mm256_set_m128(v128, v128); return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1); } + static inline void set4(const float * ptr, Data * vs) { + auto v = set4(ptr); + vs[0] = _mm512_shuffle_ps(v, v, 0x00); + vs[1] = _mm512_shuffle_ps(v, v, 0x55); + vs[2] = _mm512_shuffle_ps(v, v, 0xaa); + vs[3] = _mm512_shuffle_ps(v, v, 0xff); + } static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); } static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); } static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); } @@ -15193,6 +15443,13 @@ struct F16 { auto v128 = _mm_loadu_ps(ptr); return _mm256_set_m128(v128, v128); } + static inline void set4(const float * ptr, Data * vs) { + auto v = set4(ptr); + vs[0] = _mm256_shuffle_ps(v, v, 0x00); + vs[1] = _mm256_shuffle_ps(v, v, 0x55); + vs[2] = _mm256_shuffle_ps(v, v, 0xaa); + vs[3] = _mm256_shuffle_ps(v, v, 0xff); + } static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); } static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); } static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); } @@ -15388,7 +15645,119 @@ struct HelperQ80 final : public BaseHelper { } } }; +} +void * iqk_repack_k(int int_type_k, int nek0, int nek1, int nek2, int nek3, long nbk1, long nbk2, long nbk3, + const void * data, void * work, int ith, int nth, int& repacked_type, uint64_t& row_size) { + repacked_type = int_type_k; + auto type_k = ggml_type(int_type_k); + if (type_k != GGML_TYPE_Q8_0 || nek0%QK8_0 != 0) return work; + int nrows = nek1*nek2*nek3; + if (nrows%8 != 0) return work; + repacked_type = int(GGML_TYPE_Q8_0_R8); + row_size = ggml_row_size(GGML_TYPE_Q8_0, nek0); + void * result = (char *)work + nrows*row_size; + int npt = 8*((nrows/8 + nth - 1)/nth); + int first = npt*ith; + if (first >= nrows) return result; + int last = std::min(first + npt, nrows); + const block_q8_0 * x8[8]; + auto y = (block_q8_0_r8 *)((char *)work + first*row_size); + int nblock = nek0/QK8_0; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif + for (int row = first; row < last; row += 8) { + int ik3 = row/(nek1*nek2); + int ik2 = (row - ik3*nek1*nek2)/nek1; + int ik1 = row - ik3*nek1*nek2 - ik2*nek1; + auto this_data = (const char *)data + ik1*nbk1 + ik2*nbk2 + ik3*nbk3; + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(this_data + k*nbk1); + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; +#ifdef __AVX2__ + auto m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs), _mm_loadu_si128((const __m128i *)x8[0][ib].qs)); + auto m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs), _mm_loadu_si128((const __m128i *)x8[1][ib].qs)); + auto m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs), _mm_loadu_si128((const __m128i *)x8[2][ib].qs)); + auto m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs), _mm_loadu_si128((const __m128i *)x8[3][ib].qs)); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + //#ifdef HAVE_FANCY_SIMD + // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + //#endif + _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3); + m0 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[4][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[0][ib].qs+1)); + m1 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[5][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[1][ib].qs+1)); + m2 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[6][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[2][ib].qs+1)); + m3 = MM256_SET_M128I(_mm_loadu_si128((const __m128i *)x8[7][ib].qs+1), _mm_loadu_si128((const __m128i *)x8[3][ib].qs+1)); + t0 = _mm256_unpacklo_epi32(m0, m1); + t1 = _mm256_unpacklo_epi32(m2, m3); + t2 = _mm256_unpackhi_epi32(m0, m1); + t3 = _mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + //#ifdef HAVE_FANCY_SIMD + // m0 = _mm256_add_epi8(m0, _mm256_set1_epi8(127)); + // m1 = _mm256_add_epi8(m1, _mm256_set1_epi8(127)); + // m2 = _mm256_add_epi8(m2, _mm256_set1_epi8(127)); + // m3 = _mm256_add_epi8(m3, _mm256_set1_epi8(127)); + //#endif + _mm256_storeu_si256((__m256i *)y[ib].qs + 4, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 5, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3); +#elif defined __ARM_NEON + for (int l = 0; l < 2; ++l) { + m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l); + m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l); + m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l); + m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0); + vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1); + vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2); + vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3); + } +#else + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; + } + } +#endif + } + y += nblock; + } + return result; +} + +namespace { template struct HelperQ80R8 : public BaseHelper { using Base = BaseHelper; @@ -15399,24 +15768,21 @@ struct HelperQ80R8 : public BaseHelper { constexpr static int block_size_q = QK8_0; using block_q8 = block_q8_0; #endif + HelperQ80R8(const char * data, int stride) : Base(data, stride) {} HelperQ80R8(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { r4 = repack(nk, q8); Base::data = (const char *)r4.data(); Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector repack(int nk, const HelperQ80& q8) { - static_assert(D%QK8_0 == 0); - GGML_ASSERT(nk%8 == 0); + static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) { constexpr int nblock = D/QK8_0; - std::vector result(nblock * nk/8); - auto y = result.data(); const block_q8_0 * x8[8]; #ifdef __ARM_NEON int8x16x2_t m0, m1, m2, m3; #endif for (int row = 0; row < nk; row += 8) { - for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride); for (int ib = 0; ib < nblock; ++ib) { for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; #ifdef __AVX2__ @@ -15498,6 +15864,15 @@ struct HelperQ80R8 : public BaseHelper { } y += nblock; } + } + + static std::vector repack(int nk, const HelperQ80& q8) { + static_assert(D%QK8_0 == 0); + GGML_ASSERT(nk%8 == 0); + constexpr int nblock = D/QK8_0; + std::vector result(nblock * nk/8); + auto y = result.data(); + repack(nk, q8.data, q8.stride, y); return result; } @@ -15952,12 +16327,13 @@ struct FlashMS { } return F16::reduce_max(vk); } - static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) { - auto m128 = _mm_loadu_si128((const __m128i *)mask+l); - m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128()); - auto m256 = _mm256_cvtepi16_epi32(m128); - auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); - return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); + static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) { + return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l))); + //auto m128 = _mm_loadu_si128((const __m128i *)mask+l); + //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128()); + //auto m256 = _mm256_cvtepi16_epi32(m128); + //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16))); + //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf)); } #ifdef __AVX512F__ static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) { @@ -16087,7 +16463,6 @@ struct FlashQKV { accumulate_qkv_1(vh, fms); return; } - F16::Data v[8]; for (int j = 0; j < q_step; ++j) { auto R = qkv_cache + D*j; if (fms.need_scaling[j] == 2) { @@ -16100,6 +16475,43 @@ struct FlashQKV { } } } +#ifdef __AVX512F__ + if constexpr ((D/F16::block_size)%4 == 0) { + F16::Data v[16]; + F16::Data vs[4]; + for (int i = 0; i < D/F16::block_size; i += 4) { + for (int l = 0; l < k_step; l += 4) { + for (int k = 0; k < 4; ++k) { + vh.load(l+k, i+0, v[4*k+0], v[4*k+1]); + vh.load(l+k, i+2, v[4*k+2], v[4*k+3]); + } + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto s3 = F16::load(R + F16::block_size*(i+2)); + auto s4 = F16::load(R + F16::block_size*(i+3)); + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[4*k+0], vs[k]); + s2 = F16::fmadd(s2, v[4*k+1], vs[k]); + s3 = F16::fmadd(s3, v[4*k+2], vs[k]); + s4 = F16::fmadd(s4, v[4*k+3], vs[k]); + } + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); + F16::store(R + F16::block_size*(i+2), s3); + F16::store(R + F16::block_size*(i+3), s4); + } + } + } + return; + } +#endif + F16::Data v[8]; +#ifdef __AVX2__ + F16::Data vs[4]; +#endif for (int i = 0; i < D/F16::block_size; i += 2) { for (int l = 0; l < k_step; l += 4) { vh.load(l+0, i, v[0], v[4]); @@ -16110,6 +16522,13 @@ struct FlashQKV { auto R = qkv_cache + D*j; auto s1 = F16::load(R + F16::block_size*(i+0)); auto s2 = F16::load(R + F16::block_size*(i+1)); +#ifdef __AVX2__ + F16::set4(fms.cache + k_step*j + l, vs); + for (int k = 0; k < 4; ++k) { + s1 = F16::fmadd(s1, v[k+0], vs[k]); + s2 = F16::fmadd(s2, v[k+4], vs[k]); + } +#else auto vs = F16::set4(fms.cache + k_step*j + l); s1 = F16::fmadd_lane0(s1, v[0], vs); s2 = F16::fmadd_lane0(s2, v[4], vs); @@ -16119,6 +16538,7 @@ struct FlashQKV { s2 = F16::fmadd_lane2(s2, v[6], vs); s1 = F16::fmadd_lane3(s1, v[3], vs); s2 = F16::fmadd_lane3(s2, v[7], vs); +#endif F16::store(R + F16::block_size*(i+0), s1); F16::store(R + F16::block_size*(i+1), s2); } @@ -16239,7 +16659,8 @@ struct FlashQKV { // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization. - qkv_cache_t qkv_cache[D*q_step] = {}; + qkv_cache_t qkv_cache[D*q_step]; // = {}; + //qkv_cache_t * qkv_cache; }; template @@ -16481,8 +16902,14 @@ struct FlashQKfp32 { MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_0_Tx, 4); MAKE_FUNCS(mul_mat_qX_0_q8_0_T, 16); -#endif if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); +#ifdef HAVE_FANCY_SIMD + if constexpr (D%32 == 0 && k_step%8 == 0) { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq); + } else { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + } +#endif MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); #endif } @@ -16514,17 +16946,23 @@ struct FlashQKfp32 { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0, 1); + if (nq == 2) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 2); + if (nq == 4) return std::make_pair(mul_mat_qX_0_q8_2_Tx, 4); MAKE_FUNCS(mul_mat_qX_1_q8_2_T>) { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_1_q8_1& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv, - float * M, float * S) { - typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; + float * M, float * S, char * qptr) { + auto q8 = (typename KHelper::block_q8 *)qptr; + if constexpr (q_step > 1 && std::is_same_v>) { + if (nq1 == q_step) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; + HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); + HelperQ80::convert(q_step, stride_q, q, q8); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + HelperQ80R8::repack(k_step, kh.data, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + return; + } + } #if FA_TIMING Perf perf(false); #endif @@ -16731,6 +17190,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, #endif } +char * get_q_storage(size_t size) { + thread_local std::vector q_storage; + if (q_storage.size() < size) q_storage.resize(size); + return q_storage.data(); +} + // Some of the methods in FlashAttn have two identical implementations that only differ by // one version using a loop over the template parameter q_step, while the other using a loop // over an input parameter nq (these are loops over the rows of q^T). I dislike this a lot, @@ -16753,44 +17218,57 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv, [[maybe_unused]] float * M, [[maybe_unused]] float * S) { - if constexpr (std::is_same_v> || std::is_same_v> || + if constexpr (std::is_same_v> || + std::is_same_v> || std::is_same_v> || - std::is_same_v>) { - compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } - else if constexpr (std::is_same_v>) { - if (nq1 >= 8) { + std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v> || + std::is_same_v>) { + constexpr size_t kMaxOnStackSize = 576; + auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + q_size = GGML_PAD(q_size, 64); + if (q_size > kMaxOnStackSize) { + auto qptr = get_q_storage(q_size); + if (nq1 >= 8) { + if constexpr (std::is_same_v>) { #if FA_TIMING - auto t1 = Perf::cur_time(); - HelperQ80R8 khr4(nk1, kh); - Perf::instance().accum(4, t1); + auto t1 = Perf::cur_time(); + HelperQ80R8 khr4(nk1, kh); + Perf::instance().accum(4, t1); #else - HelperQ80R8 khr4(nk1, kh); + HelperQ80R8 khr4(nk1, kh); #endif - compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } else{ + compute_helper_q, VHelper, FlashQKfp32>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + return; + + } + if constexpr (std::is_same_v>) { +#if FA_TIMING + auto t1 = Perf::cur_time(); + HelperQ8KVR8 khr4(nk1, kh); + Perf::instance().accum(4, t1); +#else + HelperQ8KVR8 khr4(nk1, kh); +#endif + compute_helper_q, VHelper, FlashQKfp32>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + return; + } + } compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr); + + } + else { + typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)]; + compute_helper_q>( + kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8); } } - else if constexpr (std::is_same_v>) { - if (nq1 >= 8) { -#if FA_TIMING - auto t1 = Perf::cur_time(); - HelperQ8KVR8 khr4(nk1, kh); - Perf::instance().accum(4, t1); -#else - HelperQ8KVR8 khr4(nk1, kh); -#endif - compute_helper_q, VHelper, FlashQKfp32>( - khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } else{ - compute_helper_q>( - kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); - } - } else { + else { compute_helper>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S); } @@ -17234,39 +17712,61 @@ template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { - if (nk1 >= 256) { //4096) { - if (nq1 >= 64) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; + if (nk1 >= 512) { + if (nq1 >= 128) { + int n_step = nq1/128; FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - return; + fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(128*n_step)) return; + } + if (nq1 >= 64) { + int n_step = nq1/64; + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(64*n_step)) return; } if (nq1 >= 32) { + int n_step = nq1/32; FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - return; + fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(32*n_step)) return; } if (nq1 >= 16) { + int n_step = nq1/16; FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - return; + fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(16*n_step)) return; } } if (nq1 >= 8) { + int n_step = nq1/8; FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(8*n_step)) return; } else if (nq1 >= 4) { + int n_step = nq1/4; FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(4*n_step)) return; } else if (nq1 >= 2) { + int n_step = nq1/2; FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); - } - else { - FlashAttn fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); + if (update(2*n_step)) return; } + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S); } #ifdef __AVX512BF16__ @@ -17327,11 +17827,11 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ60 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; -#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); @@ -17360,6 +17860,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ80 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; + case GGML_TYPE_Q8_0_R8: { + HelperQ80R8 kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); + } break; case GGML_TYPE_Q8_KV: { HelperQ8KV kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); @@ -17368,11 +17872,11 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ60 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; -#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); } break; +#if GGML_IQK_FA_ALL_QUANTS case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S); @@ -17393,9 +17897,10 @@ inline bool flash_attn_is_supported(ggml_type type) { #endif #if GGML_IQK_FA_ALL_QUANTS if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || - type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true; + type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL || type == GGML_TYPE_Q8_0_R8) return true; #else - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV) return true; + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q6_0 || type == GGML_TYPE_Q8_KV || type == GGML_TYPE_Q8_0_R8 + || type == GGML_TYPE_Q4_0) return true; #endif return false; } @@ -17404,25 +17909,35 @@ template inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) { + auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) { + nq1 -= n; + if (nq1 == 0) return true; + q += n*stride_q; + mask += n*stride_m; + qkv += n*stride_qkv; + if (M && S) { M += n; S += n; } + return false; + }; if (nq1 >= 8) { + int n_step = nq1/8; FlashAttn<576, 512, 8, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(8*n_step)) return; } - else if (nq1 >= 4) { + if (nq1 >= 4) { + int n_step = nq1/4; FlashAttn<576, 512, 4, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(4*n_step)) return; } - else { - FlashAttn<576, 512, 1, step_k> fa(scale, softcap); - fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (nq1 >= 2) { + int n_step = nq1/2; + FlashAttn<576, 512, 2, step_k> fa(scale, softcap); + fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); + if (update(2*n_step)) return; } - //if (nq1 % 8 == 0) { - // FlashAttn<576, 512, 8, step_k> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - //} else { - // FlashAttn<576, 512, 1, step_k> fa(scale, softcap); - // fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); - //} + FlashAttn<576, 512, 1, step_k> fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S); } template @@ -17436,6 +17951,12 @@ inline bool iqk_deepseek_helper(ggml_type type_k, iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); return true; } + if (type_k == GGML_TYPE_Q8_0_R8) { + HelperQ80R8<576, step_k> kh((const char *)k, stride_k); + HelperQ80<512, step_k> vh((const char *)v, stride_v); + iqk_deepseek_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S); + return true; + } if (type_k == GGML_TYPE_Q6_0) { HelperQ60<576, step_k> kh((const char *)k, stride_k); HelperQ60<512, step_k> vh((const char *)v, stride_v); @@ -17558,6 +18079,23 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k } #endif + if (nk1%128 == 0) { + switch (Dk) { + case 64: + iqk_flash_helper_T< 64, 64, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 96: + iqk_flash_helper_T< 96, 96, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 128: + iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 192: + iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + case 256: + iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv, M, S); break; + default: + return false; + } + return true; + } if (nk1%64 == 0) { switch (Dk) { case 64: From 99b87a375fdf8cc409c4a95cf451f0462f56f71b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 29 Apr 2025 07:22:06 +0200 Subject: [PATCH 103/132] Update AUTHORS Add @ubergarm --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 035292f1..818feaf6 100644 --- a/AUTHORS +++ b/AUTHORS @@ -3,3 +3,4 @@ saood06 Nexes the Elder <124105151+Nexesenex@users.noreply.github.com> fairydreaming <166155368+fairydreaming@users.noreply.github.com> Stanisław Szymczyk +ubergarm From 1064f5bc312f61e5a1b7ef3fef918be300f74641 Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Tue, 29 Apr 2025 16:02:08 +0800 Subject: [PATCH 104/132] Apply Qwen3 PR from llama.cpp (#355) --- convert_hf_to_gguf.py | 7 + gguf-py/gguf/constants.py | 36 ++++ src/llama.cpp | 391 +++++++++++++++++++++++++++++++++++++- 3 files changed, 433 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a6ab09c0..20d27a5c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1938,6 +1938,13 @@ class Qwen2MoeModel(Model): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("Qwen3ForCausalLM") +class Qwen3Model(Qwen2Model): + model_arch = gguf.MODEL_ARCH.QWEN3 + +@Model.register("Qwen3MoeForCausalLM") +class Qwen3MoeModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3MOE @Model.register("GPT2LMHeadModel") class GPT2Model(Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6e03f9a4..4f0681b4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -331,6 +331,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PLAMO: "plamo", @@ -701,6 +703,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.QWEN3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN3MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/llama.cpp b/src/llama.cpp index e5988fd4..939e6e4b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -202,6 +202,8 @@ enum llm_arch { LLM_ARCH_QWEN, LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN3, + LLM_ARCH_QWEN3MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PLAMO, @@ -228,7 +230,7 @@ enum llm_arch { LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, - LLM_ARCH_GRANITE = 46, + LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_COHERE2, LLM_ARCH_UNKNOWN, @@ -254,6 +256,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN3, "qwen3" }, + { LLM_ARCH_QWEN3MOE, "qwen3moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PLAMO, "plamo" }, @@ -940,6 +944,45 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_QWEN3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN3MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PHI2, { @@ -5430,6 +5473,22 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN3MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -6571,6 +6630,10 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } + if (model.arch == LLM_ARCH_QWEN3MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } + if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); @@ -7474,6 +7537,83 @@ static bool llm_load_tensors( layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}); } } break; + case LLM_ARCH_QWEN3: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_QWEN3MOE: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}); + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(n_expert > 0); + GGML_ASSERT(n_expert_used > 0); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + } + } break; case LLM_ARCH_PHI2: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -12093,6 +12233,245 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_qwen3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_qwen3moe() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, + cb, il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_phi2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -16116,6 +16495,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen2moe(); } break; + case LLM_ARCH_QWEN3: + { + result = llm.build_qwen3(); + } break; + case LLM_ARCH_QWEN3MOE: + { + result = llm.build_qwen3moe(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); @@ -19909,6 +20296,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_QWEN3: + case LLM_ARCH_QWEN3MOE: case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: From 9ba362706c998902752caf31d99fe077ed7d4faa Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 29 Apr 2025 10:05:38 +0200 Subject: [PATCH 105/132] Add missing enum values for qwen3 and qwen3moe (#356) Co-authored-by: Iwan Kawrakow --- gguf-py/gguf/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4f0681b4..d6feb9ea 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -198,6 +198,8 @@ class MODEL_ARCH(IntEnum): QWEN = auto() QWEN2 = auto() QWEN2MOE = auto() + QWEN3 = auto() + QWEN3MOE = auto() PHI2 = auto() PHI3 = auto() PLAMO = auto() From 4c2bee0bedbf8f46a4d36ba66508d872a47ee28c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 30 Apr 2025 10:45:43 +0200 Subject: [PATCH 106/132] Fix IQK_FA_ALL_QUANTS on AVX2 (#360) * Fix IQK_FA_ALL_QUANTS on AVX2 * Make it also work, not just compile --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5f916584..5e49089a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -16076,8 +16076,13 @@ struct HelperIQ4nl final : public BaseHelper { constexpr static int block_size_q = QK8_0; #else HelperIQ4nl(const char * data, int stride) : Base(data, stride) {} +#ifdef HAVE_FANCY_SIMD using block_q8 = block_q8_2; constexpr static int block_size_q = QK8_2; +#else + using block_q8 = block_q8_0; + constexpr static int block_size_q = QK8_0; +#endif #endif // Needed for v * softmax(k * q) @@ -16974,7 +16979,11 @@ struct FlashQKfp32 { #ifdef __aarch64__ MAKE_FUNCS(mul_mat_qX_0_q8_0 Date: Wed, 30 Apr 2025 15:11:29 +0200 Subject: [PATCH 107/132] Update README.md (#352) * Update README.md * Edits * Updates --------- Co-authored-by: Iwan Kawrakow --- README.md | 305 +++++++++--------------------------------------------- 1 file changed, 49 insertions(+), 256 deletions(-) diff --git a/README.md b/README.md index 207d6de2..a04c1130 100644 --- a/README.md +++ b/README.md @@ -1,269 +1,62 @@ -# llama.cpp clone with better CPU performance +# ik_llama.cpp: llama.cpp fork with better CPU performance [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) ## TL;DR -This repository is a clone of [llama.cpp](https://github.com/ggerganov/llama.cpp) with the following improvements -* Better implementation of CPU matrix multiplications (`AVX2` and `ARM_NEON`) for `fp16/fp32` and all k-, i-, and legacy `llama.cpp` quants, that leads to a significant improvement in prompt processing (PP) speed, typically in the range of 2X, but up to 4X for some quantization types. Token generation (TG) also benefits, but to a lesser extent due to TG being memory bound -* Faster CPU inference for MoE models with similar performance gains -* Implementation of the [Bitnet b1.58](https://huggingface.co/1bitLLM/bitnet_b1_58-3B) model for the CPU (`AVX2` and `ARM_NEON`) and GPU (`CUDA` and `Metal`). This implementation is much faster than the unmerged `llama.cpp` [PR-8151](https://github.com/ggerganov/llama.cpp/pull/8151) +This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) with better CPU and hybrid GPU/CPU performance, new SOTA quantization types, first-class Bitnet support, better DeepSeek performance via MLA, FlashMLA, fused MoE operations and tensor overrides for hybrid GPU/CPU inference, row-interleaved quant packing, etc. -If you are not already familiar with [llama.cpp](https://github.com/ggerganov/llama.cpp), it is better to start there. For those familiar with `llama.cpp`, everything here works the same as in `llama.cpp` (or at least the way `llama.cpp` worked when I last synced on Aug 12 2024). +## Latest News -Note that I have published some, but not all, of the code in this repository in a series of [llamafile](https://github.com/Mozilla-Ocho/llamafile) PRs ([394](https://github.com/Mozilla-Ocho/llamafile/pull/394), [405](https://github.com/Mozilla-Ocho/llamafile/pull/405), [428](https://github.com/Mozilla-Ocho/llamafile/pull/428), [435](https://github.com/Mozilla-Ocho/llamafile/pull/435), [453](https://github.com/Mozilla-Ocho/llamafile/pull/453), and [464](https://github.com/Mozilla-Ocho/llamafile/pull/464)) +* April 29 2025: Qwen3 support added +* April 26 2025: GLM-4 support added +* April 26 2025: Command-A support added +* April 22 2025: Support for the latest Microsoft Bitnet model added +* April 21 2025: ik_llama.cpp builds and runs successfully on Android (using termux) +* April 17 2025: Better CPU Flash Attention token generation performance +* April 13 2025: `IQ1_M` quantization improvements +* April 10 2025: LLaMA-4 support added +* April 7 2025: `IQ2_XS` quantization improvements +* April 3 2025: Much faster MoE implementation on Metal +* April 1 2025: Quantization improvements for `Q2_K, Q4_K, Q5_K, Q4_1, Q5_1` +* March 28 2025: Quantization imrovements for `Q4_0, Q5_0, Q6_0, Q3_K, Q6_K, IQ4_XS, IQ4_NL` +* March 25 2025: Better MoE performance on CUDA +* March 23 2025: Better batched processing speed for DeepSeek models +* March 22 2025: Gemma3 support added +* March 21 2025: FlashMLA-3: fastest CPU-only inference for DeepSeek models +* March 18 2025: reduce compute buffer size +* March 17 2025: FlashMLA-2 performance improvements +* March 12 2025: Allow `Q8_0` KV cache with FlashMLA-2 on CUDA +* March 10 2025: Better TG performance for MoE models on CUDA +* March 9 2025: FlashMLA on CUDA +* March 8 2025: Faster FlashMLA CPU implementation +* March 7 2025: Custom quantization mixes using regular expressions +* March 5 2025: FlashMLA on CUDA +* March 3 2025: Introducing FlashMLA - MLA with Flash Attention +* March 1 2025: Smart Expert Reduction for faster DeepSeek inference +* Feb 27 2025: MLA without transposed cache +* Feb 25 2025: tensor overrides for better control where model weights are stored (GPU or CPU) +* Feb 23 2025: fused FFN ops for faster MoE inference +* Feb 23 2025: `sweep-bench` - better performance benchmarking +* Feb 20 2025: fast GEMM/GEMV for `IQ1_S` +* Feb 19 2025: `Q8_KV` - new type for 8-bit KV-cache quantization +* Feb 13 2025: allow `Q8_0` quantized cache with MLA +* Feb 11 2025: Flash Attention support for DeepSeek models +* Feb 9 2025: MLA for DeepSeek models +* Jan 23 2025: DeepSeek-V3 support added -The implementation of matrix-matrix and matrix-vector multiplications is in a single C++ source file (`iqk_mul_mat.cpp`) with just two interface functions `iqk_mul_mat` (`fp16/fp32` and quantized matrix multiplications) and `iqk_mul_mat_moe` (as `iqk_mul_mat` but meant to be used for the FFN part of a MoE model). Under the hood `iqk_mul_mat_moe` uses the same implementation as `iqk_mul_mat`, with the only difference being where results are stored in memory. Bitnet quantization related stuff is in `iqk-quantize.cpp`. +## Resources -## Why? +There is no single point of reference describing all new `ik_llama.cpp` features. Pull requests often contain detailed information, so browsing the PRs is often the best way to learn about new features and how to use them. In addition +* [The Wiki page](https://github.com/ikawrakow/ik_llama.cpp/wiki) has performance comparisons to mainline `llama.cpp` +* [This guide](https://github.com/ikawrakow/ik_llama.cpp/discussions/258) is a good place to start if you came here because of DeepSeek models +* [This discussion](https://github.com/ikawrakow/ik_llama.cpp/discussions/266) is about running DeepSeek-V3/R1 on a 16 x 3090 setup +* [This discussion](https://github.com/ikawrakow/ik_llama.cpp/discussions/8) describes the new quantization types available in `ik_llama.cpp` -Mostly out of curiosity: -* Justine Tunney's `tinyBLAS`, which she contributed to `llama.cpp` in [PR 6414](https://github.com/ggerganov/llama.cpp/pull/6414), only works for `Q4_0`, `Q8_0` and `fp16/bf16` models. In the surrounding discussion about possibly extending `tinyBLAS` to k- and i-quants, she felt that k-quants are [not amenable to block-tiling](https://github.com/ggerganov/llama.cpp/pull/6840#issuecomment-2072995387), which is required to improve performance. This statement piqued my curiosity, so here we are. -* Bitnet-1.58b has been one of the [most discussed topics](https://github.com/ggerganov/llama.cpp/issues/5761#issuecomment-2198380366) in the `llama.cpp` project, so eventually I decided to see how efficiently one can implement a ternary model +## Contributing -Curiosity aside, improved CPU performance may be (or may become) important in practice. According to The Register, 70% of AI inference [is done on the CPU of mobile phones](https://www.theregister.com/2024/05/30/arm_cortex_x925_ai_cores/?td=rt-3a), at least in the Android world (but I haven't come around to actually comparing performance on a phone). With ever increasing number of LLM model parameters, and with Meta's 400B model just released, the CPU may become the only viable option for people not willing (or not able to) rent/buy uber expensive GPU instances capable of running such models. Granted, one would need a pretty beefy computer to run a 400B model, and inference speed will be sluggish, but at least one will not need to spend the equivalent of a luxury apartment in the downtown of the city where I live to buy the GPU system capable of running the model. +Contributions in form of pull requests, issue submissions (bug reports, feature requests), or general discussions, are welcome. -## Performance comparison to llama.cpp - -The results in the following tables are obtained with these parameters: -* Model is LLaMA-v3-8B for `AVX2` and LLaMA-v2-7B for `ARM_NEON` -* The `AVX2` CPU is a 16-core Ryzen-7950X -* The `ARM_NEON` CPU is M2-Max -* `tinyBLAS` is enabled in `llama.cpp` -* `llama.cpp` results are for `build: 081fe431 (3441)`, which was the current `llama.cpp` master branch when I pulled on July 23 2024. -* The projects are built without `CUDA` support, no `BLAS`, and Accelerate framework disabled - -### Prompt processing - -Here I set the number of threads to be equal to the number of (performance) cores of the CPU, so 16 threads for the Ryzen-7950X and 8 threads for the M2-Max. The following table summarizes the results. To not make the table too long, I have listed only quantized models containing predominantly one quantization type (i.e., excluded the `QX_K - Medium/Large` variants, which are typically a mix of `QX_K` and `Q(X+1)_K`, as well as `IQ2_S` and `IQ3_XS`). - -The command line to generate the benchmark data is -``` -./bin/llama-bench -m $model -p 512 -n 0 -t $num_threads -ngl 0 -``` - -| Quantization| size | backend | threads | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | -| ----------- | ---------: | ---------- | ------: | ---------------: | ---------------: | ------: | -| 8B F16 | 14.96 GiB | AVX2 | 16 | 112.37 ± 0.40 | 131.27 ± 0.38 | 1.168 | -| 7B F16 | 12.55 GiB | NEON | 8 | 90.28 ± 1.25 | 95.34 ± 0.15 | 1.056 | -| 8B Q8_0 | 7.95 GiB | AVX2 | 16 | 118.07 ± 0.53 | 134.00 ± 0.47 | 1.135 | -| 7B Q8_0 | 6.67 GiB | NEON | 8 | 77.25 ± 1.81 | 94.14 ± 1.15 | 1.219 | -| 8B Q4_0 | 4.35 GiB | AVX2 | 16 | 104.46 ± 0.33 | 130.20 ± 0.29 | 1.246 | -| 7B Q4_0 | 3.57 GiB | NEON | 8 | 65.46 ± 0.79 | 76.22 ± 0.71 | 1.164 | -| 8B Q4_1 | 4.77 GiB | AVX2 | 16 | 57.83 ± 0.24 | 160.69 ± 0.49 | 2.779 | -| 7B Q4_1 | 3.95 GiB | NEON | 8 | 37.40 ± 0.50 | 65.83 ± 0.98 | 1.760 | -| 8B Q5_0 | 5.22 GiB | AVX2 | 16 | 53.50 ± 0.35 | 122.62 ± 0.48 | 2.292 | -| 7B Q5_0 | 4.34 GiB | NEON | 8 | 29.31 ± 0.51 | 67.51 ± 1.17 | 2.303 | -| 8B Q5_1 | 5.64 GiB | AVX2 | 16 | 50.85 ± 0.36 | 147.15 ± 0.47 | 2.894 | -| 7B Q5_1 | 4.72 GiB | NEON | 8 | 26.02 ± 0.37 | 58.49 ± 0.85 | 2.248 | -| 8B Q2_K_S | 2.78 GiB | AVX2 | 16 | 110.11 ± 0.28 | 192.47 ± 1.35 | 1.748 | -| 7B Q2_K_S | 2.16 GiB | NEON | 8 | 35.44 ± 0.06 | 77.93 ± 1.64 | 2.199 | -| 8B Q3_K_S | 3.41 GiB | AVX2 | 16 | 77.42 ± 0.36 | 181.64 ± 0.44 | 2.346 | -| 7B Q3_K_S | 2.75 GiB | NEON | 8 | 26.79 ± 0.03 | 59.38 ± 1.08 | 2.216 | -| 8B Q4_K_S | 4.36 GiB | AVX2 | 16 | 98.92 ± 0.34 | 185.35 ± 0.39 | 1.874 | -| 7B Q4_K_S | 3.59 GiB | NEON | 8 | 46.55 ± 0.67 | 76.31 ± 0.38 | 1.639 | -| 8B Q5_K_S | 5.21 GiB | AVX2 | 16 | 69.44 ± 0.31 | 179.62 ± 0.69 | 2.587 | -| 7B Q5_K_S | 4.33 GiB | NEON | 8 | 30.18 ± 0.23 | 65.34 ± 0.79 | 2.165 | -| 8B Q6_K | 6.14 GiB | AVX2 | 16 | 74.89 ± 0.26 | 181.86 ± 0.55 | 2.428 | -| 7B Q6_K | 5.15 GiB | NEON | 8 | 28.12 ± 1.24 | 60.75 ± 1.15 | 2.160 | -| 8B IQ2_XXS | 2.23 GiB | AVX2 | 16 | 42.57 ± 0.16 | 126.63 ± 0.55 | 2.975 | -| 7B IQ2_XXS | 1.73 GiB | NEON | 8 | 20.87 ± 0.20 | 64.29 ± 1.12 | 3.080 | -| 8B IQ2_XS | 2.42 GiB | AVX2 | 16 | 46.45 ± 0.27 | 125.46 ± 0.43 | 2.701 | -| 7B IQ2_XS | 1.89 GiB | NEON | 8 | 22.77 ± 0.21 | 51.15 ± 0.24 | 2.246 | -| 8B IQ2_M | 2.74 GiB | AVX2 | 16 | 40.76 ± 0.18 | 113.07 ± 0.48 | 2.774 | -| 7B IQ2_M | 2.20 GiB | NEON | 8 | 14.95 ± 0.26 | 44.87 ± 0.50 | 3.001 | -| 8B IQ3_XXS | 3.04 GiB | AVX2 | 16 | 31.95 ± 0.20 | 109.86 ± 0.45 | 3.438 | -| 7B IQ3_XXS | 2.41 GiB | NEON | 8 | 14.40 ± 0.10 | 53.58 ± 0.85 | 3.721 | -| 8B IQ3_S | 3.42 GiB | AVX2 | 16 | 28.04 ± 0.08 | 96.28 ± 0.45 | 3.434 | -| 7B IQ3_S | 2.75 GiB | NEON | 8 | 12.08 ± 0.30 | 49.72 ± 0.06 | 4.116 | -| 8B IQ4_XS | 4.13 GiB | AVX2 | 16 | 68.98 ± 0.31 | 180.34 ± 0.55 | 2.614 | -| 7B IQ4_XS | 3.37 GiB | NEON | 8 | 40.67 ± 1.97 | 75.11 ± 1.97 | 1.847 | -| 8B IQ4_NL | 4.35 GiB | AVX2 | 16 | 59.94 ± 0.21 | 129.06 ± 0.43 | 2.153 | -| 7B IQ4_NL | 3.56 GiB | NEON | 8 | 34.36 ± 0.81 | 76.02 ± 1.36 | 2.212 | - -We see that `llama.cpp` achieves respectable performance for `fp16`, `Q8_0`, and `Q4_0`, being only up to 25% slower than this implementation. This is thanks to the use of Justine Tunney's `tinyBLAS`, which is utilized for these quantization types. For all other quants we observe performance gains in the `1.75X - 4X` range, which is not a small feat considering that the `ggml` matrix multiplication functions has been rewritten several times since `llama.cpp` was first published. Performance gains are larger for i-quants due to the higher quant unpacking cost (see discussion in "To tile or not to tile") - -### Token generation - -On the Ryzen-7950X TG is memory bound, and for many quantization types peak performance is achieved at just 4 threads. Hence, only results for 2 and 4 threads are shown for `AVX2`. The M2-Max has a much more capable memory subsystem and as a result performance keep increasing up to 8 threads. Thus, results are given for up to 8 threads for `ARM_NEON`. - -The command line to generate the data was -``` -./bin/llama-bench -m $model -p 0 -n 128 -t $num_threads -ngl 0 -``` - -| Quantization| size | backend | threads | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | -| ---------- | ---------: | ---------- | ------: | ---------------: | ---------------: | ------: | -| 8B F16 | 14.96 GiB | AVX2 | 1 | 2.20 ± 0.00 | 2.25 ± 0.00 | 1.023 | -| | | | 2 | 3.63 ± 0.00 | 3.68 ± 0.00 | 1.014 | -| | | | 4 | 4.20 ± 0.00 | 4.20 ± 0.00 | 1.000 | -| 7B F16 | 12.55 GiB | NEON | 2 | 6.94 ± 0.27 | 7.40 ± 0.01 | 1.066 | -| | | | 4 | 8.73 ± 0.01 | 8.83 ± 0.01 | 1.011 | -| | | | 6 | 9.05 ± 0.02 | 9.05 ± 0.01 | 1.000 | -| 8B Q8_0 | 7.95 GiB | AVX2 | 2 | 5.03 ± 0.00 | 7.87 ± 0.00 | 1.565 | -| | | | 4 | 7.40 ± 0.00 | 7.82 ± 0.00 | 1.057 | -| 7B Q8_0 | 6.67 GiB | NEON | 2 | 8.29 ± 0.44 | 12.07 ± 0.10 | 1.456 | -| | | | 4 | 13.53 ± 0.03 | 15.77 ± 0.08 | 1.166 | -| | | | 8 | 16.24 ± 0.10 | 16.94 ± 0.04 | 1.043 | -| 8B Q4_0 | 4.35 GiB | AVX2 | 2 | 6.36 ± 0.00 | 10.28 ± 0.00 | 1.616 | -| | | | 4 | 10.97 ± 0.06 | 13.55 ± 0.07 | 1.235 | -| 7B Q4_0 | 3.57 GiB | NEON | 2 | 9.77 ± 0.02 | 13.69 ± 0.03 | 1.401 | -| | | | 4 | 17.82 ± 0.06 | 23.98 ± 0.11 | 1.346 | -| | | | 8 | 26.63 ± 0.41 | 29.86 ± 0.04 | 1.121 | -| 8B Q4_1 | 4.77 GiB | AVX2 | 2 | 5.11 ± 0.00 | 11.45 ± 0.00 | 2.241 | -| | | | 4 | 9.08 ± 0.02 | 12.58 ± 0.00 | 1.385 | -| 7B Q4_1 | 3.95 GiB | NEON | 2 | 9.11 ± 0.06 | 14.62 ± 0.04 | 1.605 | -| | | | 4 | 17.04 ± 0.09 | 24.08 ± 0.28 | 1.413 | -| | | | 8 | 25.26 ± 0.24 | 27.23 ± 0.14 | 1.078 | -| 8B Q5_0 | 5.22 GiB | AVX2 | 2 | 5.31 ± 0.01 | 8.30 ± 0.01 | 1.563 | -| | | | 4 | 9.40 ± 0.01 | 11.47 ± 0.00 | 1.220 | -| 7B Q5_0 | 4.34 GiB | NEON | 2 | 7.26 ± 0.06 | 7.52 ± 0.00 | 1.036 | -| | | | 4 | 13.63 ± 0.18 | 14.16 ± 0.10 | 1.039 | -| | | | 8 | 22.55 ± 0.35 | 24.34 ± 0.22 | 1.079 | -| 8B Q5_1 | 5.64 GiB | AVX2 | 2 | 4.52 ± 0.00 | 8.86 ± 0.00 | 1.960 | -| | | | 4 | 7.72 ± 0.05 | 10.68 ± 0.03 | 1.383 | -| 7B Q5_1 | 4.72 GiB | NEON | 2 | 6.51 ± 0.01 | 6.42 ± 0.03 | 0.986 | -| | | | 4 | 12.26 ± 0.18 | 12.21 ± 0.14 | 0.996 | -| | | | 8 | 20.33 ± 0.52 | 21.85 ± 0.22 | 1.075 | -| 8B Q2_K_S | 2.78 GiB | AVX2 | 2 | 11.30 ± 0.00 | 13.06 ± 0.01 | 1.156 | -| | | | 4 | 18.70 ± 0.00 | 19.04 ± 0.65 | 1.014 | -| 7B Q2_K_S | 2.16 GiB | NEON | 2 | 8.42 ± 0.05 | 11.97 ± 0.10 | 1.422 | -| | | | 4 | 15.74 ± 0.01 | 22.09 ± 0.08 | 1.403 | -| | | | 8 | 27.35 ± 0.05 | 38.32 ± 0.05 | 1.401 | -| 8B Q3_K_S | 3.41 GiB | AVX2 | 2 | 8.58 ± 0.00 | 10.82 ± 0.00 | 1.261 | -| | | | 4 | 15.26 ± 0.01 | 16.25 ± 0.01 | 1.065 | -| 7B Q3_K_S | 2.75 GiB | NEON | 2 | 6.40 ± 0.02 | 9.12 ± 0.09 | 1.425 | -| | | | 4 | 12.17 ± 0.00 | 17.11 ± 0.03 | 1.406 | -| | | | 8 | 22.04 ± 0.08 | 31.39 ± 0.31 | 1.424 | -| 8B Q4_K_S | 4.36 GiB | AVX2 | 2 | 9.61 ± 0.00 | 10.72 ± 0.01 | 1.116 | -| | | | 4 | 13.24 ± 0.31 | 13.28 ± 0.01 | 1.003 | -| 7B Q4_K_S | 3.59 GiB | NEON | 2 | 11.15 ± 0.05 | 12.93 ± 0.09 | 1.160 | -| | | | 4 | 20.24 ± 0.16 | 23.49 ± 0.29 | 1.161 | -| | | | 8 | 25.76 ± 0.07 | 28.31 ± 0.22 | 1.099 | -| 8B Q5_K_S | 5.21 GiB | AVX2 | 2 | 7.45 ± 0.00 | 9.73 ± 0.00 | 1.306 | -| | | | 4 | 11.05 ± 0.33 | 11.43 ± 0.02 | 1.034 | -| 7B Q5_K_S | 4.33 GiB | NEON | 2 | 7.20 ± 0.04 | 8.81 ± 0.04 | 1.224 | -| | | | 4 | 13.62 ± 0.15 | 16.81 ± 0.16 | 1.234 | -| | | | 8 | 20.56 ± 0.19 | 23.96 ± 0.14 | 1.165 | -| 8B Q6_K | 6.14 GiB | AVX2 | 2 | 7.53 ± 0.00 | 9.42 ± 0.00 | 1.251 | -| | | | 4 | 9.74 ± 0.00 | 9.97 ± 0.01 | 1.024 | -| 7B Q6_K | 5.15 GiB | NEON | 2 | 6.85 ± 0.04 | 8.30 ± 0.06 | 1.212 | -| | | | 4 | 13.03 ± 0.05 | 15.47 ± 0.17 | 1.187 | -| | | | 8 | 18.52 ± 0.07 | 20.67 ± 0.08 | 1.116 | -| 8B IQ2_XXS | 2.23 GiB | AVX2 | 2 | 5.33 ± 0.01 | 6.40 ± 0.00 | 1.201 | -| | | | 4 | 10.06 ± 0.03 | 11.76 ± 0.03 | 1.169 | -| 7B IQ2_XXS | 1.73 GiB | NEON | 2 | 5.07 ± 0.04 | 5.22 ± 0.05 | 1.030 | -| | | | 4 | 9.63 ± 0.00 | 9.91 ± 0.07 | 1.029 | -| | | | 8 | 17.40 ± 0.50 | 18.65 ± 0.22 | 1.072 | -| 8B IQ2_XS | 2.42 GiB | AVX2 | 2 | 5.83 ± 0.00 | 6.55 ± 0.00 | 1.123 | -| | | | 4 | 10.88 ± 0.09 | 12.07 ± 0.07 | 1.109 | -| 7B IQ2_XS | 1.89 GiB | NEON | 2 | 5.52 ± 0.01 | 5.60 ± 0.00 | 1.014 | -| | | | 4 | 10.50 ± 0.01 | 11.15 ± 0.00 | 1.062 | -| | | | 8 | 18.19 ± 1.30 | 20.94 ± 0.19 | 1.151 | -| 8B IQ2_M | 2.74 GiB | AVX2 | 2 | 5.12 ± 0.01 | 5.17 ± 0.00 | 1.010 | -| | | | 4 | 9.60 ± 0.28 | 9.68 ± 0.16 | 1.008 | -| 7B IQ2_M | 2.20 GiB | NEON | 2 | 3.73 ± 0.02 | 4.53 ± 0.00 | 1.214 | -| | | | 4 | 7.14 ± 0.05 | 8.70 ± 0.06 | 1.218 | -| | | | 8 | 11.99 ± 0.48 | 16.41 ± 0.05 | 1.369 | -| 8B IQ3_XXS | 3.04 GiB | AVX2 | 2 | 4.06 ± 0.01 | 5.00 ± 0.00 | 1.232 | -| | | | 4 | 7.75 ± 0.02 | 9.13 ± 0.45 | 1.178 | -| 7B IQ3_XXS | 2.41 GiB | NEON | 2 | 3.53 ± 0.00 | 3.82 ± 0.00 | 1.082 | -| | | | 4 | 6.74 ± 0.04 | 7.42 ± 0.07 | 1.103 | -| | | | 8 | 11.96 ± 0.40 | 13.19 ± 0.29 | 1.103 | -| 8B IQ3_S | 3.42 GiB | AVX2 | 2 | 3.62 ± 0.00 | 4.06 ± 0.00 | 1.122 | -| | | | 4 | 6.80 ± 0.01 | 7.62 ± 0.10 | 1.121 | -| 7B IQ3_S | 2.75 GiB | NEON | 2 | 2.96 ± 0.01 | 3.21 ± 0.03 | 1.084 | -| | | | 4 | 5.68 ± 0.01 | 6.25 ± 0.05 | 1.100 | -| | | | 8 | 10.32 ± 0.25 | 11.11 ± 0.37 | 1.077 | -| 8B IQ4_XS | 4.13 GiB | AVX2 | 2 | 8.08 ± 0.00 | 11.35 ± 0.00 | 1.405 | -| | | | 4 | 13.36 ± 0.72 | 14.32 ± 0.24 | 1.072 | -| 7B IQ4_XS | 3.37 GiB | NEON | 2 | 9.87 ± 0.03 | 12.06 ± 0.00 | 1.222 | -| | | | 4 | 17.78 ± 0.23 | 22.06 ± 0.28 | 1.241 | -| | | | 8 | 27.62 ± 0.09 | 29.70 ± 0.39 | 1.075 | -| 8B IQ4_NL | 4.35 GiB | AVX2 | 2 | 5.52 ± 0.00 | 10.26 ± 0.00 | 1.859 | -| | | | 4 | 10.78 ± 0.01 | 13.69 ± 0.08 | 1.270 | -| 7B IQ4_NL | 3.56 GiB | NEON | 2 | 8.32 ± 0.01 | 13.54 ± 0.01 | 1.627 | -| | | | 4 | 15.89 ± 0.00 | 24.28 ± 0.29 | 1.528 | -| | | | 8 | 26.56 ± 0.36 | 29.87 ± 0.08 | 1.125 | - -Here gains are generally lower compared to PP due to TG performance being limited by memory bandwidth. Nevertheless, for some quants/architectures/threads the speedup is quite remarkable (e.g., almost a factor of 2 for `Q5_1` on `AVX2` with 2 threads). - -## MoE models - -There is [PR-6840](https://github.com/ggerganov/llama.cpp/pull/6840) from Justine Tunney in `llama.cpp`, but it has not been merged since April 23, so I'll compare performance to the master branch for Mixtral-8x7B. As Mixtral8x7B quantization is quite a lengthy process, the following table shows data only for `Q4_K_S` (a commonly used k-quant, 4 bit), `Q5_0` (a legacy quant, 5 bit), and `IQ4_XXS` (a 3-bit i-quant) - -| model | size | backend | threads | test | t/s (llama.cpp) | t/s (iqk_mul_mat)| Speedup | -| ------------ | ---------: | ---------- | ------: | -------: | ---------------: | ---------------: | ------: | -| 8x7B Q4_K_S | 48.75 GiB | AVX2 | 16 | pp512 | 54.92 ± 0.23 | 102.94 ± 0.37 | 1.874 | -| | | NEON | 8 | pp512 | 23.54 ± 1.56 | 38.32 ± 0.54 | 1.628 | -| | | AVX2 | 4 | tg128 | 7.80 ± 0.07 | 7.83 ± 0.09 | 1.004 | -| | | NEON | 8 | tg128 | 14.95 ± 0.25 | 15.28 ± 0.24 | 2.022 | -| 8x7B IQ3_XXS | 33.07 GiB | AVX2 | 16 | pp512 | 17.58 ± 0.04 | 68.45 ± 0.22 | 3.894 | -| | | NEON | 8 | pp512 | 7.75 ± 0.04 | 34.67 ± 0.40 | 4.474 | -| | | AVX2 | 4 | tg128 | 4.60 ± 0.01 | 5.45 ± 0.09 | 1.185 | -| | | AVX2 | 8 | tg128 | 8.04 ± 0.65 | 9.83 ± 0.06 | 1.223 | -| | | AVX2 | 16 | tg128 | 10.42 ± 0.01 | 10.57 ± 0.01 | 1.014 | -| | | NEON | 8 | tg128 | 6.19 ± 1.16 | 7.27 ± 0.14 | 1.174 | -| 8x7B Q5_0 | 59.11 GiB | AVX2 | 16 | pp512 | 29.06 ± 0.43 | 62.67 ± 0.32 | 2.157 | -| | | NEON | 8 | pp512 | 15.17 ± 0.51 | 27.36 ± 1.03 | 1.804 | -| | | AVX2 | 4 | tg128 | 5.44 ± 0.10 | 6.81 ± 0.06 | 1.252 | -| | | NEON | 8 | tg128 | 12.03 ± 0.77 | 12.41 ± 1.27 | 1.032 | - - -## Bitnet-1.58B - -Two implementations are provided -* `IQ1_BN` - uses 1.625 bits-per-weight (bpw) -* `IQ2_BN` - uses 2.0 bpw - -`IQ2_BN` is faster for PP (CPU and GPU, although the PP performance difference on CUDA is very minor). `IQ1_BN` can arrive at a higher TG performance on the Ryzen-7950X (given enough threads) because of the smaller model size, but it is always slower on the GPU and on the M2-Max CPU. - -There is the unmerged [PR 8151](https://github.com/ggerganov/llama.cpp/pull/8151) in `llama.cpp` that implements Bitnet-1.58B for the CPU (`AVX` and `ARM_NEON`, no GPU implementation). The following table compares performance between this repo and `PR-8151` in `llama.cpp`. The CUDA results were obtained on an RTX-4080, the Metal results on a 30-core M2-Max GPU. - -| model | size | backend | threads | test | t/s (llama.cpp) | t/s (this repo)| Speedup | -| ----------- | ---------: | ---------- | ------: | -----: | ---------------: | -------------: | ------: | -| 3B - IQ1_BN | 729.64 MiB | AVX2 | 16 | pp512 | 120.61 ± 0.48 | 423.19 ± 1.28 | 3.509 | -| | | NEON | 8 | pp512 | 46.64 ± 0.02 | 205.90 ± 0.88 | 4.415 | -| | | CUDA | 8 | pp512 | - | 10660 ± 170 | - | -| | | Metal | 8 | pp512 | - | 698.25 ± 1.91 | - | -| | | AVX2 | 2 | tg128 | 15.79 ± 0.01 | 22.13 ± 0.02 | 1.402 | -| | | AVX2 | 4 | tg128 | 28.64 ± 1.72 | 40.14 ± 0.04 | 1.402 | -| | | AVX2 | 8 | tg128 | 48.91 ± 0.08 | 61.79 ± 0.09 | 1.263 | -| | | AVX2 | 16 | tg128 | 57.73 ± 0.05 | 60.79 ± 0.05 | 1.053 | -| | | NEON | 2 | tg128 | 11.43 ± 0.04 | 16.87 ± 0.02 | 1.476 | -| | | NEON | 4 | tg128 | 21.11 ± 0.05 | 30.66 ± 0.11 | 1.452 | -| | | NEON | 8 | tg128 | 37.36 ± 0.07 | 55.21 ± 0.16 | 1.478 | -| | | CUDA | 8 | tg128 | - | 301.44 ± 0.12 | - | -| | | Metal | 8 | tg128 | - | 76.70 ± 0.07 | - | -| 3B - IQ2_BN | 873.65 MiB | AVX2 | 16 | pp512 | 151.39 ± 0.35 | 540.82 ± 2.48 | 3.572 | -| | | NEON | 8 | pp512 | 46.54 ± 0.03 | 242.05 ± 0.34 | 5.201 | -| | | CUDA | 8 | pp512 | - | 10800 ± 160 | - | -| | | Metal | 8 | pp512 | - | 723.19 ± 0.53 | - | -| | | AVX2 | 2 | tg128 | 18.93 ± 0.02 | 38.34 ± 0.08 | 2.026 | -| | | AVX2 | 4 | tg128 | 34.54 ± 0.06 | 56.29 ± 0.07 | 1.630 | -| | | AVX2 | 8 | tg128 | 52.97 ± 0.07 | 53.44 ± 0.08 | 1.009 | -| | | AVX2 | 16 | tg128 | 51.84 ± 0.25 | 53.46 ± 0.07 | 1.031 | -| | | NEON | 2 | tg128 | 11.40 ± 0.02 | 32.01 ± 0.27 | 2.808 | -| | | NEON | 4 | tg128 | 20.99 ± 0.00 | 56.45 ± 0.11 | 2.689 | -| | | NEON | 8 | tg128 | 37.28 ± 0.08 | 89.77 ± 0.70 | 2.408 | -| | | CUDA | 8 | tg128 | - | 322.10 ± 0.07 | - | -| | | Metal | 8 | tg128 | - | 110.39 ± 0.13 | - | - -We can make the following observations: -* For prompt processing this Bitnet-1.58b implementation is massively better than PR-8151 in `llama.cpp`, with gains between 3.4X and 5.2X! -* We get `PP-512 = 520 t/s` for the 2.0 bpw variant on the Ryzen-7950X, which costs less than $500. Hey, who needs a GPU? -* For low number of threads (2), this implementation is also much faster than PR-8151 for TG, where speed gains are between 1.4X and 2.8X. As we become memory bound on the Ryzen-7950X, the speed advantage goes away there for sufficiently high number of threads. But on the M2-Max this implementation is 1.4X (1.625 bpw) or 2.4X faster even at 8 threads -* Looking at TG on the M2-Max, the GPU looks a bit like wasted silicon (90 vs 110 t/s for TG-128 and the 2.0 bpw variant). If the GPU transistors had been spent to double the M2 number of CPU cores (and all memory bandwidth is given to the CPU), the CPU would be wiping the floor with the GPU. -* I'm of course kidding with the above. Still, it seems there are massive inefficiencies in the `llama.cpp` Metal implementation that start showing up when matrix multiplications become very fast as is the case here. The difference between CPU and GPU prompt processing speed is typically at least a factor of 7 in favor of the GPU on the M2-Max, but it is only around a factor of 3 here. -* It is worth noting that one needs to offload the token embeddings tensor to the GPU, else performance on CUDA/Metal is significantly lower. Bitnet uses the same tensor for token embeddings and for output. Mainline `llama.cpp` currently puts the token embeddings tensor on the CPU, and this results in running the matrix multiplication with the output tensor on the CPU. This most likely affects other models as well (e.g., Gemma), but I haven't yet looked into this. - -To reproduce these results: -* Clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B -* Run `python3 --outtype f16 path_to_bitnet` to convert to GGUF -* Run `./bin/llama-quantize path_to_bitnet/ggml-model-f16.gguf quantized.gguf [iq1_bn | iq2_bn]`. Note: no imatrix is required (and, if you provide one, it is ignored) -* Caveat: only the 3B Bitnet variant works. The smaller Bitnet models contain tensors with number of columns that are not even a multiple of 32, so basically no `llama.cpp` quant will work for these. - -## To tile or not to tile - -The common wisdom for efficient matrix multiplications is to use block tiling, and this is also used here for `fp16/fp32` matrices. But block tiling does not somehow magically reduce the amount of computation that needs to get done. Performance gains are simply due to the better utilization of memory caches. When dealing with quantized matrix multiplications, there is an additional factor that comes into play: the quantized data needs to be unpacked to 8-bit integers before being used in the matrix multiplication multiply-add operations. Depending on quantization type, this unpacking can represent a significant fraction of the overall computation cost. Hence, for best performance, one would want to reuse the unpacked quants as much as possible, thus spending some fraction of the available vector registers to hold the unpacked data. But when using block tiling, one also needs a certain number of vector registers for accumulating results. For instance, on `AVX2` (16 vector registers available), for `fp16/fp32` models best performance is achieved with `2 x 6` tiles (where the `2` refers to rows in the left matrix and is measured in units of the vector register size, so 16/8 floats for `fp16/fp32`, and `6` is for the number of columns in the right matrix). Unpacking quantized data works best when done in blocks of 128 or 256 quants so that, if we wanted to keep unpacked quants for 2 rows, we would need at least 8 vector registers, thus being left with less than 8 registers for result accumulation, so at best `2 x 3` tiles. In practice one needs addition vector registers for various constants that are typically needed for de-quantization, so that, at the end, it becomes better to use `1 x N` "tiles", i.e., a row-wise multiplication where each row in the left matrix is multiplied with `N` columns in the right matrix, thus reusing the unpacked data `N` times. This (i.e., amortizing de-quantization cost) is the main mechanism for seeding up quantized matrix multiplications. Having started with quantized matrices, and having gone from tiles to a row-wise implementation after some experimentation, I did try row-wise multiplication for float matrices first. Performance was not quite as good as for block-tiling, but I did get up to 90-95% of the speed of `tinyBLAS` that way before switching the `fp16/fp32` implementation to `2 x 6` (`AVX2`) or `5 x 5` (`AVX512` and `ARM_NEON`) block-tiles. But even for for `Q8_0 x Q8_0` multiplications, where there is basically no de-quantization cost, row-wise multiplication is faster than tiling (and hence this implementation beats `tinyBLAS`, which uses block-tiling also for `Q8_0`). +## License +MIT From d37add8b39634a6bc31a2e3bda8bbe24297c66db Mon Sep 17 00:00:00 2001 From: saood06 Date: Fri, 2 May 2025 00:07:24 -0500 Subject: [PATCH 108/132] Fix model architecture name (#366) Co-authored-by: junhuihe --- src/llama.cpp | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 939e6e4b..dd9d82dc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -227,6 +227,7 @@ enum llm_arch { LLM_ARCH_GLM4, LLM_ARCH_BITNET, LLM_ARCH_BITNET_25, + LLM_ARCH_BITNET_B158, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, @@ -281,6 +282,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_BITNET_25, "bitnet-25" }, + { LLM_ARCH_BITNET_B158, "bitnet-b1.58" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, @@ -1419,6 +1421,34 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, }, }, + { + LLM_ARCH_BITNET_B158, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, { LLM_ARCH_T5, { @@ -5235,7 +5265,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -5830,6 +5860,7 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BITNET_B158: case LLM_ARCH_BITNET_25: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -8350,6 +8381,7 @@ static bool llm_load_tensors( layer.ffn_up_scale = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_BITNET_B158: case LLM_ARCH_BITNET_25: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -15376,7 +15408,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_bitnet_25() { + struct ggml_cgraph * build_bitnet_158() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens @@ -16599,9 +16631,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_bitnet(); } break; + case LLM_ARCH_BITNET_B158: case LLM_ARCH_BITNET_25: { - result = llm.build_bitnet_25(); + result = llm.build_bitnet_158(); } break; case LLM_ARCH_COHERE2: { @@ -20293,6 +20326,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_BITNET_25: + case LLM_ARCH_BITNET_B158: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2MOE: From 1ea1df4b2d942ebd56efdcdfb922ec92d6dc1db7 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 2 May 2025 07:09:09 +0200 Subject: [PATCH 109/132] Fix FA bug on AVX2 (#364) * Fix FA bug on AVX2 * Also this was wrong --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5e49089a..6adc43cf 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17120,11 +17120,12 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, vh.reset_block(); block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); - HelperQ80::convert(q_step, stride_q, q, q8); + auto q8r = (typename HelperQ80R8::block_q8 *)qptr; + HelperQ80::convert(q_step, stride_q, q, q8r); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - HelperQ80R8::repack(k_step, kh.data, kh.stride, q8r8); - KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms); + HelperQ80R8::repack(k_step, kh.block, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8r, mr, fms); fqkv.accumulate_qkv(vh, fms); kh.next_block(); vh.next_block(); @@ -17236,7 +17237,8 @@ struct FlashAttn { std::is_same_v> || std::is_same_v>) { constexpr size_t kMaxOnStackSize = 576; - auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + //auto q_size = q_step*(Dk/KHelper::block_size_q)*sizeof(typename KHelper::block_q8); + auto q_size = q_step*(Dk/QK8_2*sizeof(block_q8_2)); q_size = GGML_PAD(q_size, 64); if (q_size > kMaxOnStackSize) { auto qptr = get_q_storage(q_size); From afcfa85756ec7a476ed1daf79ed4152625dd8c7c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 3 May 2025 14:43:55 +0300 Subject: [PATCH 110/132] Trying to fix iq1_s_r4/iq1_m_r4 quantization failure (#368) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_quantize.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 7873d3fe..b95493eb 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -6572,8 +6572,8 @@ size_t quantize_iq1_s_r4(const float * src, void * dst, int64_t nrows, int64_t n auto xb = src + k*n_per_row + kBlockSize*ibl; float sumx2 = 0; for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; - if (!sumx2) { - printf("Found block with all zeros\n"); + if (sumx2 < 1e-14f) { + //printf("Found block with all zeros\n"); // all zero int ind = 1029; // this is the grid entry with all zeros scales[4*ibl+k] = 0; @@ -6703,13 +6703,18 @@ size_t quantize_iq1_m_r4(const float * src, void * dst, int64_t nrows, int64_t n auto xb = src + k*n_per_row + kBlockSize*ibl; float sumx2 = 0; for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; - if (!sumx2) { + if (sumx2 < 1e-14f) { scales[8*ibl+2*k+0] = scales[8*ibl+2*k+1] = 0; continue; } float sigma2 = 1.5f*sumx2/kBlockSize; if (imatrix) { for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); + float sumwx = 0; + for (int j = 0; j < kBlockSize; ++j) sumwx += weight[j]*std::abs(xb[j]); + if (!sumwx) { + for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + } } else { for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); } From ab7f694b71497d216e1e7bad50bb4471feee7652 Mon Sep 17 00:00:00 2001 From: Gaolingx <947770192@qq.com> Date: Sat, 3 May 2025 20:56:29 +0800 Subject: [PATCH 111/132] cmake: force MSVC compiler charset to utf-8 (#369) --- CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index cb59656e..10049391 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,14 @@ if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) endif() +# force MSVC compiler charset to utf-8 +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + add_compile_options("$<$:/source-charset:utf-8>") + add_compile_options("$<$:/source-charset:utf-8>") + add_compile_options("$<$:/execution-charset:utf-8>") + add_compile_options("$<$:/execution-charset:utf-8>") +endif() + # # option list # From b890e01238dd73e38c8cac99e34eed2071e1c4c1 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 4 May 2025 09:02:12 +0300 Subject: [PATCH 112/132] Another attempt to fix #367 (#371) * Another attempt to fix #367 * Yet another --------- Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_quantize.cpp | 38 ++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index b95493eb..ca5e008a 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -6701,25 +6701,49 @@ size_t quantize_iq1_m_r4(const float * src, void * dst, int64_t nrows, int64_t n for (int ibl = 0; ibl < nblock; ++ibl) { for (int k = 0; k < 4; ++k) { auto xb = src + k*n_per_row + kBlockSize*ibl; - float sumx2 = 0; - for (int j = 0; j < kBlockSize; ++j) sumx2 += xb[j]*xb[j]; + float sumx2l = 0, sumx2h = 0; + for (int j = 0; j < kBlockSize/2; ++j) sumx2l += xb[j]*xb[j]; + for (int j = kBlockSize/2; j < kBlockSize; ++j) sumx2h += xb[j]*xb[j]; + float sumx2 = sumx2l + sumx2h; if (sumx2 < 1e-14f) { scales[8*ibl+2*k+0] = scales[8*ibl+2*k+1] = 0; + int ind = 1029; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[4*i + k] = ind & 255; + } + for (int i = 0; i < 2; ++i) { + y[ibl].qh[4*i+k] = (ind >> 8) | ((ind >> 8) << 4); + } continue; } float sigma2 = 1.5f*sumx2/kBlockSize; if (imatrix) { for (int j = 0; j < kBlockSize; ++j) weight[j] = imatrix[kBlockSize*ibl + j]*sqrt(sigma2 + xb[j]*xb[j]); float sumwx = 0; - for (int j = 0; j < kBlockSize; ++j) sumwx += weight[j]*std::abs(xb[j]); - if (!sumwx) { - for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + for (int j = 0; j < kBlockSize/2; ++j) sumwx += weight[j]*std::abs(xb[j]); + if (sumwx < 1e-14f) { + for (int j = 0; j < kBlockSize/2; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); + } + sumwx = 0; + for (int j = kBlockSize/2; j < kBlockSize; ++j) sumwx += weight[j]*std::abs(xb[j]); + if (sumwx < 1e-14) { + for (int j = kBlockSize/2; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); } } else { for (int j = 0; j < kBlockSize; ++j) weight[j] = sqrt(sigma2 + xb[j]*xb[j]); } - iq1m_process_1block(xb+ 0, weight+ 0, L, scales.data() + 8*ibl + 2*k+0, index+0, &shift1, pairs); - iq1m_process_1block(xb+16, weight+16, L, scales.data() + 8*ibl + 2*k+1, index+2, &shift2, pairs); + if (sumx2l > 1e-14f) { + iq1m_process_1block(xb+ 0, weight+ 0, L, scales.data() + 8*ibl + 2*k+0, index+0, &shift1, pairs); + } else { + scales[8*ibl+2*k+0] = 0; + index[0] = index[1] = 1029; + } + if (sumx2h > 1e-14f) { + iq1m_process_1block(xb+16, weight+16, L, scales.data() + 8*ibl + 2*k+1, index+2, &shift2, pairs); + } else { + scales[8*ibl+2*k+1] = 0; + index[2] = index[3] = 1029; + } max[k] = std::max(max[k], std::max(scales[8*ibl+2*k+0], scales[8*ibl+2*k+1])); for (int i = 0; i < 4; ++i) { y[ibl].qs[4*i + k] = index[i] & 255; From ce2b0292e18cd8dd87776797fa455e7fc4cfeed9 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 4 May 2025 09:17:44 +0300 Subject: [PATCH 113/132] CUDA: faster FA TG for GQA models (#370) * cuda: WIP MMA FA * Use MMA for TG also when quantized --------- Co-authored-by: Iwan Kawrakow --- ggml/src/CMakeLists.txt | 2 + ggml/src/ggml-cuda/common.cuh | 56 + ggml/src/ggml-cuda/cp-async.cuh | 46 + ggml/src/ggml-cuda/fattn-common.cuh | 333 ++++++ ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1047 +++++++++++++++++ ggml/src/ggml-cuda/fattn.cu | 129 +- ggml/src/ggml-cuda/mma_new.cuh | 396 +++++++ ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_1.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_2.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_32-ncols2_1.cu | 10 + ...ttn-mma-f16-instance-ncols1_32-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_64-ncols2_1.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_1.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 10 + 23 files changed, 2158 insertions(+), 11 deletions(-) create mode 100644 ggml/src/ggml-cuda/cp-async.cuh create mode 100644 ggml/src/ggml-cuda/fattn-mma-f16.cuh create mode 100644 ggml/src/ggml-cuda/mma_new.cuh create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 17a9e2eb..74ac5374 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -321,6 +321,8 @@ if (GGML_CUDA) list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "ggml-cuda/template-instances/fattn-mma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2eba527f..0a7f7f83 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -46,10 +46,14 @@ #define CC_VOLTA 700 #define CC_TURING 750 #define CC_AMPERE 800 +#define CC_ADA_LOVELACE 890 #define CC_OFFSET_AMD 1000000 +#define CC_OFFSET_MTHREADS 0x0100000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) +#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < CC_OFFSET_MTHREADS) +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= CC_OFFSET_AMD) #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -134,6 +138,49 @@ typedef float2 dfloat2; #define INT8_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE + +#ifdef __CUDA_ARCH_LIST__ +constexpr bool ggml_cuda_has_arch_impl(int) { + return false; +} + +template +constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) { + return arch == first || ggml_cuda_has_arch_impl(arch, rest...); +} + +constexpr bool ggml_cuda_has_arch(const int arch) { + return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); +} + +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) { + if (cur == 0) { + GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch); + } + return cur; +} + +template +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) { + if (first <= arch && first > cur) { + return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...); + } else { + return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...); + } +} + +constexpr int ggml_cuda_highest_compiled_arch(const int arch) { + return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__); +} +#else +static int ggml_cuda_highest_compiled_arch(const int arch) { + return arch; +} +#endif // __CUDA_ARCH_LIST__ + static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } @@ -146,6 +193,15 @@ static constexpr bool int8_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_TURING; } +// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. +static bool new_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING; +} + +static bool cp_async_available(const int cc) { + return cc < CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= CC_AMPERE; +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh new file mode 100644 index 00000000..ecb65999 --- /dev/null +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -0,0 +1,46 @@ +// Simplified API for asynchronous data loading. + +#include "common.cuh" + +// Copies data from global to shared memory, cg == cache global. +// Both the src and dst pointers must be aligned to 16 bit. +// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. +// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. +// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. +template +static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { + static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); +#ifdef CP_ASYNC_AVAILABLE +#if CUDART_VERSION >= 11040 + if (preload == 256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else +#endif // CUDART_VERSION >= 11040 + { + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } +#else + GGML_UNUSED(dst); + GGML_UNUSED(src); + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} + +// Makes each thread wait until its asynchronous data copies are done. +// This does NOT provide any additional synchronization. +// In particular, when copying data with multiple warps a call to __syncthreads will be needed. +static __device__ __forceinline__ void cp_async_wait_all() { +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_all;"); +#else + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 08bb5cd3..5e8ee0f6 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -862,3 +862,336 @@ void launch_fattn( (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); } + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_mma_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template // D == head size +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_mma_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; + dst += D * gridDim.z*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; +} + +template +void launch_fattn_mma( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE +) { + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + GGML_ASSERT(Q->ne[3] == 1); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + + int parallel_blocks = 1; + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + const dim3 block_dim(warp_size, nwarps, 1); + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = 2*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + } else { + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_mma_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_mma_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 00000000..af38071d --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,1047 @@ +#include "common.cuh" +#include "cp-async.cuh" +#include "mma_new.cuh" +#include "fattn-common.cuh" + +using namespace ggml_cuda_mma; + +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // If cp.async is available, load up to the highest power of 2 in D asynchronously: +#ifdef CP_ASYNC_AVAILABLE + static_assert(D >= 64 && D < 512, "bad D"); + constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); + + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; + constexpr int stride_i = WARP_SIZE / chunks_per_row; +#pragma unroll + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); + const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; + + cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + } +#else + constexpr int k0_sync_start = 0; +#endif // CP_ASYNC_AVAILABLE + static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + + // If D is not a power of 2, the rest is loaded synchronously. + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= k0_sync_start) { + continue; + } + +#pragma unroll + for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + } + } + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); +#ifdef CP_ASYNC_AVAILABLE + constexpr int preload = KQ_per_iter * sizeof(half); + constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + + cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } +#else + constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); + + tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +#endif // CP_ASYNC_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_KV, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, + const int kb0) { +#ifdef INT8_MMA_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + const int k_VKQ_0 = kb0 * KQ_per_iter; + tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); +#else + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + + if (use_logit_softcap) { + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); + + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } + } + + // Values per KQ column are spread across 4 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } + } + } + } + + { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { +#pragma unroll + for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } + } + +#ifdef CP_ASYNC_AVAILABLE + // Preload K tile for next iteration: + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); + } +#else + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q1, + const int stride_Q2, + const int stride_KV, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef INT8_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + static_assert(D % nwarps == 0, "bad D"); + static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: + extern __shared__ half2 tile_K[]; +#ifdef CP_ASYNC_AVAILABLE + half2 * tile_V = tile_K + KQ_per_iter*D2_padded; +#else + half2 * tile_V = tile_K; +#endif // CP_ASYNC_AVAILABLE + half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + + tile_B Q_B[D/(2*tile_B::J) * ntiles]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } + + // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (jt*ncols1 + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + { + const int j0 = (threadIdx.y / np) * cols_per_warp; + +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + } + } + } + } + + __syncthreads(); + + // Preload mask and K data for first iteration when using cp_async: +#ifdef CP_ASYNC_AVAILABLE + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); +#endif // CP_ASYNC_AVAILABLE + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With cp_async there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. +#ifdef CP_ASYNC_AVAILABLE + if (nwarps*cols_per_warp > KQ_per_iter) { + __syncthreads(); + } +#endif // CP_ASYNC_AVAILABLE + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + + // Write VKQ accumulators to shared memory in column-major format. + // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // Also for np > 1 the combination is done via these values in shared memory. + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); + + tile_K[jc_cwd*D2_padded + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } + } + } + + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } + + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + } + + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + } + } + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + + if (np > 1) { + __syncthreads(); + } + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + } + } + } + } + } + + if (np > 1) { + __syncthreads(); + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +__launch_bounds__(nwarps*WARP_SIZE, 2) +static __global__ void flash_attn_mma_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float logit_softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if defined(INT8_MMA_AVAILABLE) + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half2); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); +#else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) +} + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + constexpr int ncols = ncols1 * ncols2; + constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; + constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; + constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); + constexpr int cols_per_warp = ntiles * tile_B::I; + + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); + + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + + const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); + const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); + + const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_mma_ext_f16; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_mma_ext_f16; + } + + launch_fattn_mma + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); +} + + +#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) + +// Kernels with ncols == 128 are only 4% faster due to register pressure. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 55116058..fce66c20 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,10 +12,94 @@ #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" +#include "fattn-mma-f16.cuh" #include "fattn.cuh" #include +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const float use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio == 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio == 2) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); +} + static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -371,8 +455,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; @@ -389,7 +476,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); @@ -398,23 +485,43 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fp16_mma_available(cc)) { - if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + if (precision == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + } } else { - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); + } } return; } - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations + // So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster. + // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache. + //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; + //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies; + const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0; + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (precision == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } else if(Q->ne[0] <= 128) { + } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - return; } + return; } - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + // We need this because I haven't adapted the MMA kernels to work for different + // K and V head sizes. + if (K->ne[0] != V->ne[0]) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + } + + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); } diff --git a/ggml/src/ggml-cuda/mma_new.cuh b/ggml/src/ggml-cuda/mma_new.cuh new file mode 100644 index 00000000..5bba41c6 --- /dev/null +++ b/ggml/src/ggml-cuda/mma_new.cuh @@ -0,0 +1,396 @@ +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. + +#include "common.cuh" + + +#if CUDART_VERSION >= 11080 + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + int ret = 0; + +#ifdef INT8_MMA_AVAILABLE + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "=r"(ret) : "r"(x)); +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) + return ret; +} + +#else + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + // Imagine transposing row-major matrix to column-major matrix. + const int src_i_low = 2 * (threadIdx.x % 4); + const int src_i_high = src_i_low + 1; + const int src_j = threadIdx.x / 4; + + const int src_laneid_low = src_i_low * 4 + src_j / 2; + const int src_laneid_high = src_i_high * 4 + src_j / 2; + + const int shift_low = ((src_j + 0) % 2) * 16; + const int shift_high = ((src_j + 1) % 2) * 16; + + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + + return ret_low | ret_high; +} + +#endif // CUDART_VERSION >= 11080 + +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} + +namespace ggml_cuda_mma { + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && (J == 4 || J == 8)) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 16) { + return ((l / 2) % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return 4 * l + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x % 4) + l % 2; + } else if constexpr (I == 16 && J == 16) { + return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + + return ret; + } + + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // INT8_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(xs0, stride); + GGML_UNUSED(t); +#endif // INT8_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // INT8_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef INT8_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(t); + GGML_UNUSED(xs0); + GGML_UNUSED(stride); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); +#else + // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef INT8_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } +} diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu new file mode 100644 index 00000000..80108615 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu new file mode 100644 index 00000000..66161c0a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu new file mode 100644 index 00000000..ee88c72a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu new file mode 100644 index 00000000..d888a5a4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu new file mode 100644 index 00000000..d93a2d08 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu new file mode 100644 index 00000000..617464c9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu new file mode 100644 index 00000000..970d2b68 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu new file mode 100644 index 00000000..65cd377c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu new file mode 100644 index 00000000..f4a8bf34 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu new file mode 100644 index 00000000..de191a8a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu new file mode 100644 index 00000000..e8cb0e1b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu new file mode 100644 index 00000000..a532e962 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu new file mode 100644 index 00000000..bf25181a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu new file mode 100644 index 00000000..378c132e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu new file mode 100644 index 00000000..372641be --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu new file mode 100644 index 00000000..9ff5968b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 8, 8); From 7cb6a76cd0ae54909cdbffa95f163c077827dfc5 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 4 May 2025 11:49:29 +0300 Subject: [PATCH 114/132] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index a04c1130..17d19645 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,13 @@ This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) with better CPU and hybrid GPU/CPU performance, new SOTA quantization types, first-class Bitnet support, better DeepSeek performance via MLA, FlashMLA, fused MoE operations and tensor overrides for hybrid GPU/CPU inference, row-interleaved quant packing, etc. +>[!IMPORTANT] +>The new GGUFs for DeepSeek-V3/R1/Lite do not work in this repository. This is due to the backwards incompatibe change in mainline `llama.cpp` that [added MLA support](https://github.com/ggml-org/llama.cpp/pull/12801) +>2.5 months after MLA was available here, and worked with the original DeepSeek GGUFs. Please use the original GGUF or, if you don't have one, convert the HF safetnosrs using the Python conversion scrip in this repository. + ## Latest News +* May 4 2025: Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see PR #370 * April 29 2025: Qwen3 support added * April 26 2025: GLM-4 support added * April 26 2025: Command-A support added From 13281282986fb6783d0d7d64b3610bfb7085e749 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 4 May 2025 12:06:47 +0300 Subject: [PATCH 115/132] Update README.md --- README.md | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 17d19645..15f4d31d 100644 --- a/README.md +++ b/README.md @@ -12,42 +12,42 @@ This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) ## Latest News -* May 4 2025: Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see PR #370 +* May 4 2025: 🚀 Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see [PR #370](https://github.com/ikawrakow/ik_llama.cpp/pull/370) * April 29 2025: Qwen3 support added * April 26 2025: GLM-4 support added * April 26 2025: Command-A support added * April 22 2025: Support for the latest Microsoft Bitnet model added * April 21 2025: ik_llama.cpp builds and runs successfully on Android (using termux) -* April 17 2025: Better CPU Flash Attention token generation performance +* April 17 2025: 🚀 Better CPU Flash Attention token generation performance * April 13 2025: `IQ1_M` quantization improvements * April 10 2025: LLaMA-4 support added * April 7 2025: `IQ2_XS` quantization improvements -* April 3 2025: Much faster MoE implementation on Metal +* April 3 2025: 🚀 Much faster MoE implementation on Metal * April 1 2025: Quantization improvements for `Q2_K, Q4_K, Q5_K, Q4_1, Q5_1` * March 28 2025: Quantization imrovements for `Q4_0, Q5_0, Q6_0, Q3_K, Q6_K, IQ4_XS, IQ4_NL` -* March 25 2025: Better MoE performance on CUDA -* March 23 2025: Better batched processing speed for DeepSeek models +* March 25 2025: 🚀 Better MoE performance on CUDA +* March 23 2025: 🚀 Better batched processing speed for DeepSeek models * March 22 2025: Gemma3 support added -* March 21 2025: FlashMLA-3: fastest CPU-only inference for DeepSeek models -* March 18 2025: reduce compute buffer size -* March 17 2025: FlashMLA-2 performance improvements +* March 21 2025: 🚀 FlashMLA-3: fastest CPU-only inference for DeepSeek models +* March 18 2025: Reduce compute buffer size +* March 17 2025: 🚀 FlashMLA-2 performance improvements * March 12 2025: Allow `Q8_0` KV cache with FlashMLA-2 on CUDA -* March 10 2025: Better TG performance for MoE models on CUDA -* March 9 2025: FlashMLA on CUDA -* March 8 2025: Faster FlashMLA CPU implementation +* March 10 2025: 🚀 Better TG performance for MoE models on CUDA +* March 9 2025: 🚀 FlashMLA on CUDA +* March 8 2025: 🚀 Faster FlashMLA CPU implementation * March 7 2025: Custom quantization mixes using regular expressions -* March 5 2025: FlashMLA on CUDA -* March 3 2025: Introducing FlashMLA - MLA with Flash Attention +* March 5 2025: 🚀 FlashMLA on CUDA +* March 3 2025: 🚀 Introducing FlashMLA - MLA with Flash Attention * March 1 2025: Smart Expert Reduction for faster DeepSeek inference * Feb 27 2025: MLA without transposed cache -* Feb 25 2025: tensor overrides for better control where model weights are stored (GPU or CPU) -* Feb 23 2025: fused FFN ops for faster MoE inference +* Feb 25 2025: Tensor overrides for better control where model weights are stored (GPU or CPU) +* Feb 23 2025: 🚀 Fused FFN ops for faster MoE inference * Feb 23 2025: `sweep-bench` - better performance benchmarking -* Feb 20 2025: fast GEMM/GEMV for `IQ1_S` +* Feb 20 2025: 🚀 Fast GEMM/GEMV for `IQ1_S` * Feb 19 2025: `Q8_KV` - new type for 8-bit KV-cache quantization -* Feb 13 2025: allow `Q8_0` quantized cache with MLA -* Feb 11 2025: Flash Attention support for DeepSeek models -* Feb 9 2025: MLA for DeepSeek models +* Feb 13 2025: Allow `Q8_0` quantized cache with MLA +* Feb 11 2025: 🚀 Flash Attention support for DeepSeek models +* Feb 9 2025: 🚀 MLA for DeepSeek models * Jan 23 2025: DeepSeek-V3 support added ## Resources From f7c9a0f036951fecab32e056df954ebc54f8688f Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 4 May 2025 12:45:00 +0300 Subject: [PATCH 116/132] CUDA: MMQ for IQ4_KS (#374) * WIP * WIP: still getting illegal memory access * CUDA: MMQ for iq4_ks now works ~25% faster than dequantize+cuBLAS, ~10% slower than Q4_0 MMQ. --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/mmq.cu | 8 +- ggml/src/ggml-cuda/mmq.cuh | 148 +++++++++++++----- .../template-instances/mmq-instance-iq4_ks.cu | 5 + ggml/src/ggml-cuda/vecdotq.cuh | 12 ++ 4 files changed, 133 insertions(+), 40 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 3b959182..67897a83 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -14,6 +14,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; + const int64_t nb01 = src0->nb[1]; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; @@ -22,7 +23,6 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; @@ -31,7 +31,7 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, nb01, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; switch (src0->type) { case GGML_TYPE_Q4_0: @@ -91,6 +91,9 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_IQ4_NL: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ4_KS: + mul_mat_q_case(ctx, args, stream); + break; default: GGML_ABORT("fatal error"); break; @@ -128,6 +131,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_KS: mmq_supported = true; break; default: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 753848d7..148697e2 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -82,6 +82,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_KS: return MMQ_Q8_1_DS_LAYOUT_D4; default: GGML_ABORT("fatal error"); @@ -179,6 +180,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_IQ1_S : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_XS : return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_DP4A_TXS_Q8_0; default : return tile_x_sizes{0, 0, 0}; } } @@ -216,6 +218,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_IQ1_S : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_XS : return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_IQ4_NL : return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_KS : return MMQ_MMA_TILE_X_K_Q8_0; default : return 0; } } @@ -261,7 +264,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE @@ -283,7 +286,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + const block_q4_0 * bxi = (const block_q4_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -356,7 +359,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE @@ -378,7 +381,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + const block_q4_1 * bxi = (const block_q4_1 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; @@ -451,7 +454,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b2(bxi->qs, kqsx); const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); @@ -490,7 +493,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + const block_q5_0 * bxi = (const block_q5_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -523,7 +526,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b4(bxi->qs, kqsx); const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); @@ -560,7 +563,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + const block_q5_1 * bxi = (const block_q5_1 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; @@ -593,7 +596,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbx; + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbx; const int ql = get_int_b2(bxi->qs, kqsx); const int qh = get_int_b2(bxi->qh, kqsx%2) >> 4*(kqsx/2); @@ -623,7 +626,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_0 * bxi = (const block_q6_0 *) x + kbx0 + i*stride + kbxd; + const block_q6_0 * bxi = (const block_q6_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -656,7 +659,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbx; #ifdef INT8_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); @@ -678,7 +681,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + const block_q8_0 * bxi = (const block_q8_0 *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; @@ -1044,7 +1047,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; + const block_q2_K * bxi = (const block_q2_K *)(x + i*stride) + kbx0; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); @@ -1275,7 +1278,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; const int x_ql_0 = get_int_b2(bxi->qs, kqsx); const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); @@ -1305,7 +1308,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; const int ksc = threadIdx.x % (WARP_SIZE/8); @@ -1341,7 +1344,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + const block_q3_K * bxi = (const block_q3_K *)(x + i*stride) + kbx0; x_df[i] = bxi->d; } @@ -1412,7 +1415,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; const int qs0 = get_int_b4(bxi->qs, threadIdx.x); #ifdef INT8_MMA_AVAILABLE @@ -1433,7 +1436,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/16); @@ -1462,7 +1465,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0; x_dm[i] = bxi->dm; } @@ -1475,7 +1478,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = (const block_q4_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; @@ -1541,7 +1544,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int ky = QR5_K*threadIdx.x; const int ql = get_int_b4(bxi->qs, threadIdx.x); @@ -1574,7 +1577,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/16); @@ -1603,7 +1606,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; x_dm[i] = bxi->dm; } @@ -1616,7 +1619,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *)(x + i*stride) + kbx0; const int * scales = (const int *) bxi->scales; @@ -1683,7 +1686,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0; const int ql = get_int_b2(bxi->ql, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; @@ -1716,7 +1719,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; @@ -1733,7 +1736,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = (const block_q6_K *)(x + i*stride) + kbx0 + (threadIdx.x % (WARP_SIZE/8)) / 4; #ifdef INT8_MMA_AVAILABLE x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); @@ -1908,7 +1911,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbx; const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); @@ -1933,7 +1936,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + const block_iq4_nl * bxi = (const block_iq4_nl *)(x + i*stride) + kbx0 + kbxd; #ifdef INT8_MMA_AVAILABLE x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); @@ -1965,7 +1968,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + const block_iq2_xxs * bxi = (const block_iq2_xxs *)(x + i*stride) + kbx0; const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); const uint8_t * aux8 = (const uint8_t *) &q2; @@ -2023,7 +2026,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + const block_iq2_xs * bxi = (const block_iq2_xs *)(x + i*stride) + kbx0; const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint16_t * q2 = (const uint16_t *) &q2_packed; @@ -2079,7 +2082,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + const block_iq2_s * bxi = (const block_iq2_s *)(x + i*stride) + kbx0; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2142,7 +2145,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + const block_iq3_xxs * bxi = (const block_iq3_xxs *)(x + i*stride) + kbx0; const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * q3 = (const uint8_t *) &q3_packed; @@ -2198,7 +2201,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + const block_iq3_s * bxi = (const block_iq3_s *)(x + i*stride) + kbx0; const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2261,7 +2264,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + const block_iq1_s * bxi = (const block_iq1_s *)(x + i*stride) + kbx0; const int qs_packed = get_int_b2(bxi->qs, kqsx); const uint8_t * qs = (const uint8_t *) &qs_packed; @@ -2318,7 +2321,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx; + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0 + kbx; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); @@ -2340,7 +2343,7 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + const block_iq4_xs * bxi = (const block_iq4_xs *)(x + i*stride) + kbx0; const float d = __half2float(bxi->d); @@ -2355,6 +2358,64 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq4_ks( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = 0; // threadIdx.x / QI4_XS + const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx; + + auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4); + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, values); + const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const float * dptr = (const float *)(x + i*stride); + const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0; + const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls; +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls; +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void mmq_write_back_dp4a( const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { @@ -2576,6 +2637,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_ks; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template static __device__ void mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, @@ -2608,7 +2677,7 @@ static __device__ void mul_mat_q_process_tile( const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x + stride01*it*mmq_y, tile_x, kb0, tile_x_max_i, stride01); { const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); @@ -2889,6 +2958,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + } else { constexpr bool need_check = true; @@ -2897,6 +2967,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a mul_mat_q_stream_k_fixup<<>> (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + } } @@ -3010,6 +3081,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); // ------------------------------------------------------------------------------------------------------------------------- diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu new file mode 100644 index 00000000..940c2da8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_ks.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_KS); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index e9af29b9..cae5e04f 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1131,6 +1131,18 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) { return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); } +static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * values) { + const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; + const int8_t * q0_8 = (const int8_t *) &q0_32; + const char4 val0_8 = make_char4(values[q0_8[0]], values[q0_8[1]], values[q0_8[2]], values[q0_8[3]]); + + const int q1_32 = (q4 >> 4) & 0x0F0F0F0F; + const int8_t * q1_8 = (const int8_t *) &q1_32; + const char4 val1_8 = make_char4(values[q1_8[0]], values[q1_8[1]], values[q1_8[2]], values[q1_8[3]]); + + return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); +} + #define VDR_IQ4_NL_Q8_1_MMVQ 2 #define VDR_IQ4_NL_Q8_1_MMQ 4 From e3fec17347d5e85c262f857aab82fa1af3c76304 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 5 May 2025 08:39:10 +0300 Subject: [PATCH 117/132] Fix DeepSeek FA (#382) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/fattn.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index fce66c20..ea52fa02 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -521,6 +521,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // K and V head sizes. if (K->ne[0] != V->ne[0]) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + return; } ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); From 6c23618ca5d680bd00f06a143dc4a1b386c827e3 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 6 May 2025 08:48:11 +0300 Subject: [PATCH 118/132] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 15f4d31d..5054a173 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) with better CPU and hybrid GPU/CPU performance, new SOTA quantization types, first-class Bitnet support, better DeepSeek performance via MLA, FlashMLA, fused MoE operations and tensor overrides for hybrid GPU/CPU inference, row-interleaved quant packing, etc. >[!IMPORTANT] ->The new GGUFs for DeepSeek-V3/R1/Lite do not work in this repository. This is due to the backwards incompatibe change in mainline `llama.cpp` that [added MLA support](https://github.com/ggml-org/llama.cpp/pull/12801) ->2.5 months after MLA was available here, and worked with the original DeepSeek GGUFs. Please use the original GGUF or, if you don't have one, convert the HF safetnosrs using the Python conversion scrip in this repository. +>The new GGUFs for DeepSeek-V3/R1/Lite do not work in this repository. This is due to the backwards incompatible change in mainline `llama.cpp` that [added MLA support](https://github.com/ggml-org/llama.cpp/pull/12801) +>2.5 months after MLA was available here, and worked with the original DeepSeek GGUFs. Please use the original GGUF or, if you don't have one, convert the HF safetensors using the Python conversion script in this repository. ## Latest News From 090eae4d693e7d09bae2d86b612c941dbf5c9a96 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 7 May 2025 10:33:27 +0300 Subject: [PATCH 119/132] Fix build for Xeon Gold 6226R (#390) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 6adc43cf..136174f0 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1389,7 +1389,7 @@ static const uint32_t iq1s_grid_us[2048] = { }; #endif -#ifndef HAVE_FANCY_SIMD +#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__) const uint64_t keven_signs[128] = { 0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff, 0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff, @@ -7574,7 +7574,7 @@ struct DequantizerIQ1BN { _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), }; const __m256i m3 = _mm256_set1_epi16(3); -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); #endif @@ -7585,7 +7585,7 @@ struct DequantizerIQ1BN { auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3); auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3); auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3); -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ v1 = _mm256_permutex2var_epi8(val1, bmask, val2); v2 = _mm256_permutex2var_epi8(val3, bmask, val4); #else @@ -7866,7 +7866,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; struct EvenSignHelper { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ union sbits_t { __m128i vec; __mmask32 mask[4]; @@ -7931,7 +7931,7 @@ struct DequantizerIQ3XXS final : public BaseDequantizer { } IQK_ALWAYS_INLINE void sign_2_values(const uint16_t * signs, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(signs[2] | (signs[3] << 16)), _mm_set1_epi32(signs[0] | (signs[1] << 16))), values); #else esh.sign_value(signs[0] | (signs[1] << 16), values[0]); @@ -8106,7 +8106,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { value = _mm256_sign_epi8(value, _mm256_or_si256(signs, mone)); } inline void sign_values(const __m256i& data, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ auto partial_bits = _mm256_cvtepi16_epi8(_mm256_srli_epi16(data, 9)); auto pcnt = _mm_popcnt_epi8(partial_bits); auto full_bits = _mm_or_si128(partial_bits, _mm_slli_epi16(_mm_and_si128(pcnt, _mm_set1_epi8(1)), 7)); @@ -8156,7 +8156,7 @@ struct DequantizerIQ2XS final : public BaseDequantizer { constexpr static int minv = 43; SimpleBits bits; -#ifndef HAVE_FANCY_SIMD +#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__) Helper helper; #endif const __m256i idx_mask = _mm256_set1_epi16(511); @@ -8201,7 +8201,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { } IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const { -#ifdef HAVE_FANCY_SIMD +#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__ esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0); esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2); #else From 8a5c0410e127bd7d5221dc1d7354b1edff2385c0 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 7 May 2025 12:06:49 +0300 Subject: [PATCH 120/132] Fix DeepSeek q8_0 cache (#391) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_flash_attn.cpp | 2 +- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/iqk/iqk_flash_attn.cpp b/ggml/src/iqk/iqk_flash_attn.cpp index fd0d5dd0..610f18b7 100644 --- a/ggml/src/iqk/iqk_flash_attn.cpp +++ b/ggml/src/iqk/iqk_flash_attn.cpp @@ -81,7 +81,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float int int_type_k = int_type_k_in; auto work_buffer = work_buffer_in; - if (neq1 >= 8 || rk2 >= 8) { + if (neq1 >= 8 || (rk2 >= 8 && nek2 > 1)) { uint64_t row_size = 0; work_buffer = iqk_repack_k(int_type_k, Dk, nek1, nek2, nek3, stride_k, nbk2, nbk3, k, work_buffer_in, ith, nth, int_type_k, row_size); if (int_type_k != int_type_k_in) { diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 136174f0..54792c12 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -18033,7 +18033,7 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k auto type_v = ggml_type(int_type_v); if (Dk == 576 && Dv == 512) { - GGML_ASSERT(type_k == type_v); + GGML_ASSERT(type_k == type_v || (type_k == GGML_TYPE_Q8_0_R8 && type_v == GGML_TYPE_Q8_0)); stride_q /= sizeof(float); // q stride as float return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S); From 17c6fc6b7303915e0fd74bee53c4de8d21746d52 Mon Sep 17 00:00:00 2001 From: Gaolingx <947770192@qq.com> Date: Wed, 7 May 2025 22:04:39 +0800 Subject: [PATCH 121/132] fix some MSVC build problem. (#392) * cmake: force MSVC compiler charset to utf-8 * build: apply MSVC /bigobj option to c/cpp files only * Update CMakeLists.txt --- CMakeLists.txt | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 10049391..3e9c3cc0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,11 +54,11 @@ if (WIN32) endif() # force MSVC compiler charset to utf-8 -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") - add_compile_options("$<$:/source-charset:utf-8>") - add_compile_options("$<$:/source-charset:utf-8>") - add_compile_options("$<$:/execution-charset:utf-8>") - add_compile_options("$<$:/execution-charset:utf-8>") +if (MSVC) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/bigobj>") + add_compile_options("$<$:/bigobj>") endif() # @@ -123,6 +123,29 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + message(STATUS "Using -fsanitize=thread") + + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + message(STATUS "Using -fsanitize=address") + + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + message(STATUS "Using -fsanitize=undefined") + + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() +endif() + # # build the library # @@ -131,6 +154,11 @@ if (NOT TARGET ggml) add_subdirectory(ggml) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() + +# +# build the library +# + add_subdirectory(src) # From 30536ee369c829c7161b0170de550936b4548a6b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 7 May 2025 17:38:22 +0300 Subject: [PATCH 122/132] FlashMLA-3 for DeepSeek models on CUDA (#386) * CUDA WIP: support for FlashMLA-3 * Much better The issue was that I did not change the number of warps used for 3D matrix multiplications (wk_b * kv_cache, MoE), so we ended up using 4 warps for TG. By going to 1 warp in these cases, we get a significant boost in TG performance (tested with DeepSeek-Lite) * Sadly, the previous commit was wrong * Finalizing * Also add these * Minor * Minor tweak --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 5 + ggml/src/ggml-cuda/fattn-new-mma.cu | 1705 ++++++++++++++++++++++++++ ggml/src/ggml-cuda/fattn-new-mma.cuh | 3 + ggml/src/ggml-cuda/fattn.cu | 21 +- ggml/src/ggml-cuda/mmvq.cu | 107 +- 5 files changed, 1797 insertions(+), 44 deletions(-) create mode 100644 ggml/src/ggml-cuda/fattn-new-mma.cu create mode 100644 ggml/src/ggml-cuda/fattn-new-mma.cuh diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 1f62b882..ff6e064c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3587,6 +3587,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); } + if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) { + const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc; + int gqa = op->src[0]->ne[2]/op->src[1]->ne[2]; + return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0); + } if (op->src[1]->ne[0] > 256) { return false; } diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu new file mode 100644 index 00000000..796d9c7b --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -0,0 +1,1705 @@ +// Adapted from https://github.com/ggml-org/llama.cpp/pull/13306 +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#include "fattn-new-mma.cuh" +#include "cp-async.cuh" +#include "mma_new.cuh" +#include "fattn-common.cuh" + +using namespace ggml_cuda_mma; + +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, half2> tile_B_16; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 16, float> tile_C_KQ_16; +typedef tile<16, 4, half2> tile_C_VKQ; +typedef tile<16, 8, half2> tile_C_VKQ_16; + +// Config options for specific head sizes. +// Should not affect results, only speed/register pressure/shared memory use. +// +// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. +// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). +// Q_in_reg: whether the Q values should be kept permanently in registers. +// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. +// nbatch_V2: number of V half2 values in direction of DV to load in parallel. +// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. + +template +struct fattn_mma_f16_config; + +// +// The previous MMA version is better (faster) +// I'm keeping these around commented out for now, +// and only using the 576, 512 case. +// +//template <> +//struct fattn_mma_f16_config< 64, 64> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 32; +// static constexpr int nbatch_V2 = 32; +// static constexpr int nbatch_combine = 32; +//}; +// +//template <> +//struct fattn_mma_f16_config< 80, 80> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 40; +// static constexpr int nbatch_V2 = 40; +// static constexpr int nbatch_combine = 40; +//}; +// +//template <> +//struct fattn_mma_f16_config< 96, 96> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 48; +// static constexpr int nbatch_V2 = 48; +// static constexpr int nbatch_combine = 48; +//}; +// +//template <> +//struct fattn_mma_f16_config<112, 112> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 56; +// static constexpr int nbatch_V2 = 56; +// static constexpr int nbatch_combine = 56; +//}; +// +//template <> +//struct fattn_mma_f16_config<128, 128> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 64; +// static constexpr int nbatch_V2 = 64; +// static constexpr int nbatch_combine = 64; +//}; +// +//template <> +//struct fattn_mma_f16_config<192, 128> { +// static constexpr int nbatch_fa = 64; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 96; +// static constexpr int nbatch_V2 = 64; +// static constexpr int nbatch_combine = 64; +//}; +// +//template <> +//struct fattn_mma_f16_config<256, 256> { +// static constexpr int nbatch_fa = 32; +// static constexpr int nwarps_max = 4; +// static constexpr bool Q_in_reg = true; +// static constexpr int nstages_target = 2; +// static constexpr int nbatch_K2 = 128; +// static constexpr int nbatch_V2 = 128; +// static constexpr int nbatch_combine = 128; +//}; + +template <> +struct fattn_mma_f16_config<576, 512> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + static constexpr int nbatch_K2 = 160; + static constexpr int nbatch_V2 = 128; + static constexpr int nbatch_combine = 128; +}; + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { + + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + + if constexpr (use_cp_async) { + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + + const int chunks_per_row = D2 / h2_per_chunk; + + int k0_start = 0; +#pragma unroll + for (int stride_k = WARP_SIZE; stride_k > WARP_SIZE/32; stride_k >>= 1) { + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + + if (k0_start == k0_stop) { + continue; + } + + const int stride_i = WARP_SIZE / stride_k; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + threadIdx.x / stride_k; + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + k0_start = k0_stop; + } + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= 0) { + continue; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } + } + } + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { + static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); + + if constexpr (use_cp_async) { + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + + cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } + return; + } + + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); + + tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_K, + const int stride_V, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_Q, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + half2 * const __restrict__ tile_mask, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, + const int kb0) { +#ifdef INT8_MMA_AVAILABLE + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * c::nbatch_fa; + tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; + + // Use wide variants of tiles if ntiles >= 2. + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; + + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); + } + } + +#pragma unroll + for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { + const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (c::Q_in_reg) { +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + } else { + static_assert(ntiles == 2, "ntiles != 2 not implemented"); +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + } + } + } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } + + if (use_logit_softcap) { + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if (ntiles == 1) { + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * + __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); + + KQ_rowsum_add[l % 2] += KQ_C[k].x[l]; + } + } + } else { // ntiles > 1 + if (ncols2 > 1 || mask_h2) { +#pragma unroll + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) { + const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); + const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; + KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; + KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; + } + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]); + } + } + } + + // Values per KQ column are spread across 4 threads, does not need full warp reduce: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 2; offset >= 1; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { +#pragma unroll + for (int l = 0; l < tile_C_KQ_16::ne; ++l) { + const int KQ_index = 2*t + (l/2) % 2; + + KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]); + + KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l]; + } + } + } + } + + { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + KQ_max[col] = KQ_max_new[col]; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; + tile_B_16 * B_16 = (tile_B_16 *) B; + static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); + if (ntiles == 1) { +#pragma unroll + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); + } + } + } + + if (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + } + } + +#pragma unroll + for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { + const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; + const int i0_diff = i0_stop - i0_start; + + if (nstages == 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } + } + } + } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half2 * const __restrict__ mask_h2, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q1, + const int stride_Q2, + const int stride_K, + const int stride_V, + const int stride_mask, + const int jt, + const int kb0_start, + const int kb0_stop) { +#ifdef INT8_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; + constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = c::nbatch_K2 + 4; + constexpr int stride_tile_V = c::nbatch_V2 + 4; + + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; + + extern __shared__ half2 tile_Q[]; + half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; + half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; + + tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; + tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; + + tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; + tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } + + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (jt*ncols1 + j < ne01) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + if (c::Q_in_reg) { + const int j0 = (threadIdx.y / np) * cols_per_warp; + +#pragma unroll + for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { + if (ntiles == 1) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], + tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); + } + } + } + } + + __syncthreads(); + + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + } + + // Iterate over ne11 == previous tokens: + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + + // With multi-stage loading there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. + if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + __syncthreads(); + } + + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. + + constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); + + if constexpr (ntiles == 1) { + const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < tile_B::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + static_assert(ntiles == 2 || ntiles == 4, "bad ntiles"); + const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta + + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0) + + tile_C_VKQ_16::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } + + static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles"); + if (np > 1 && threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + } + + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset >= WARP_SIZE) { + continue; + } + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + + // Write back combined meta data: +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + } + } + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } + +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); + + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } + } + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; + } + } + } + } + + __syncthreads(); + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } + } + } + } + } + if (np > 1) { + __syncthreads(); + } + } +#else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template +__launch_bounds__(nwarps*WARP_SIZE, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float logit_softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if defined(INT8_MMA_AVAILABLE) + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + NO_DEVICE_CODE; + return; + } + + typedef fattn_mma_f16_config c; + + static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); + const int stride_K = nb11 / sizeof(half2); + const int stride_V = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half2); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } else { + constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int channel = kbc / (iter_k*iter_j); + const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape + const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + + const int kb0_start_kernel = kb0_start * kb_niter; + const int kb0_stop_kernel = kb0_stop * kb_niter; + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile + (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); +#else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; +#endif // defined(INT8_MMA_AVAILABLE) +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int iter_k = ne11 / FATTN_KQ_STRIDE; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + + const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + const int channel = kbc0 / (iter_k*iter_j); + const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + + if (jt*ncols1 + j >= ne01) { + return; + } + + dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_combine_results_new( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; + dst += D * gridDim.z*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; +} + +template +void launch_fattn_new_mma( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE +) { + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + GGML_ASSERT(Q->ne[3] == 1); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + + int parallel_blocks = 1; + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = max_blocks_per_sm*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); + } else { + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_stream_k_fixup + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_combine_results_new + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} + + +template +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + typedef fattn_mma_f16_config c; + + constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; + constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; + constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; + + const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + + constexpr int ncols = ncols1 * ncols2; + constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; + constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); + static_assert(DV % tile_A::J == 0, "bad DV"); + static_assert(ncols % cols_per_warp == 0, "bad ncols"); + + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; + + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + } + + launch_fattn_new_mma + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); +} + +void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + + //switch (Q->ne[0]) { + // case 64: + // GGML_ASSERT(V->ne[0] == 64); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + // break; + // case 80: + // GGML_ASSERT(V->ne[0] == 80); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + // break; + // case 96: + // GGML_ASSERT(V->ne[0] == 96); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + // break; + // case 112: + // GGML_ASSERT(V->ne[0] == 112); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + // break; + // case 128: + // GGML_ASSERT(V->ne[0] == 128); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + // break; + // case 192: + // GGML_ASSERT(V->ne[0] == 128); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst); + // break; + // case 256: + // GGML_ASSERT(V->ne[0] == 256); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + // break; + // case 576: { + // // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + // GGML_ASSERT(V->ne[0] == 512); + // float max_bias = 0.0f; + // memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // const bool use_gqa_opt = mask && max_bias == 0.0f; + // GGML_ASSERT(use_gqa_opt); + + // GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + // const int gqa_ratio = Q->ne[2] / K->ne[2]; + // GGML_ASSERT(gqa_ratio % 16 == 0); + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + // } break; + // default: + // GGML_ABORT("fatal error"); + // break; + //} +} + diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cuh b/ggml/src/ggml-cuda/fattn-new-mma.cuh new file mode 100644 index 00000000..40f867df --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-new-mma.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea52fa02..725b443d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,6 +13,7 @@ #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" #include "fattn-mma-f16.cuh" +#include "fattn-new-mma.cuh" #include "fattn.cuh" #include @@ -517,12 +518,28 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - // We need this because I haven't adapted the MMA kernels to work for different + // + // It turns out the new new MMA implementation is slower than the + // previous MMA implementation. + // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512, + // so no other implementation works. + // + if (new_mma_available(cc) && Q->ne[0] == 576) { + ggml_cuda_flash_attn_ext_mma_new(ctx, dst); + return; + } + + // + // We need this because I haven't adapted new MMA kernels to work for different // K and V head sizes. - if (K->ne[0] != V->ne[0]) { + // We also need it if the new MMA is not available + // + if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } + // As mentioned above, the new new MMA is slower than then the new MMA. ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); + //ggml_cuda_flash_attn_ext_mma_new(ctx, dst); } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 3d991b4d..f87ebb96 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -61,7 +61,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { } } -template +template static __device__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -73,10 +73,8 @@ static __device__ void mul_mat_vec_q( constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) - constexpr int nwarps = 1; constexpr int rows_per_cuda_block = 1; #else - constexpr int nwarps = ncols_y <= 4 ? 4 : 2; constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) @@ -139,10 +137,10 @@ static __device__ void mul_mat_vec_q( } } -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, @@ -153,11 +151,11 @@ static __global__ void mul_mat_vec_q( const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; char * cdst = (char *)dst + i2*nb2; - mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } -template -static void mul_mat_vec_q_cuda( +template +static void mul_mat_vec_q_cuda_T( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { @@ -167,61 +165,61 @@ static void mul_mat_vec_q_cuda( int id = ggml_cuda_get_device(); - int64_t nwarps = 1; - int64_t rows_per_cuda_block = 1; + int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ? + ncols_y < 4 ? 1 : 2 : 1; - if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - switch(ncols_y) { - case 1: - nwarps = 4; - rows_per_cuda_block = 1; - break; - case 2: - case 3: - case 4: - nwarps = 4; - rows_per_cuda_block = 2; - break; - case 5: - case 6: - case 7: - case 8: - nwarps = 2; - rows_per_cuda_block = 2; - break; - default: - GGML_ABORT("fatal error"); - break; - } - } + //if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + // switch(ncols_y) { + // case 1: + // nwarps = 4; + // rows_per_cuda_block = 1; + // break; + // case 2: + // case 3: + // case 4: + // nwarps = 4; + // rows_per_cuda_block = 2; + // break; + // case 5: + // case 6: + // case 7: + // case 8: + // nwarps = 2; + // rows_per_cuda_block = 2; + // break; + // default: + // GGML_ABORT("fatal error"); + // break; + // } + //} const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); switch (ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; default: GGML_ABORT("fatal error"); @@ -229,6 +227,31 @@ static void mul_mat_vec_q_cuda( } } +template +static void mul_mat_vec_q_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { + int nwarps = 1; + int id = ggml_cuda_get_device(); + if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + nwarps = ncols_y <= 4 ? 4 : 2; + } + switch (nwarps) { + case 1: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case 2: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + default: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + } +} + static void mul_mat_vec_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, From 4084ca7331611da4426d781a15a6ffa68312759e Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 7 May 2025 18:59:01 +0300 Subject: [PATCH 123/132] Update README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 5054a173..46638dd7 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,12 @@ This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) >[!IMPORTANT] >The new GGUFs for DeepSeek-V3/R1/Lite do not work in this repository. This is due to the backwards incompatible change in mainline `llama.cpp` that [added MLA support](https://github.com/ggml-org/llama.cpp/pull/12801) >2.5 months after MLA was available here, and worked with the original DeepSeek GGUFs. Please use the original GGUF or, if you don't have one, convert the HF safetensors using the Python conversion script in this repository. +> +>**Update** There is now [PR 394](https://github.com/ikawrakow/ik_llama.cpp/pull/394) addressing the issue. Would appreciate testing with DeepSeek-V3/R1. ## Latest News +* May 7 2025: 🚀 Faster TG for DeepSeek models with GPU or hybrid GPU/CPU inference. See [PR 386](https://github.com/ikawrakow/ik_llama.cpp/pull/386) for details. Caveat: Ampere or newer Nvidia GPU required * May 4 2025: 🚀 Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see [PR #370](https://github.com/ikawrakow/ik_llama.cpp/pull/370) * April 29 2025: Qwen3 support added * April 26 2025: GLM-4 support added From bc6ae515ceb14eeaf198e00251a9689539cea176 Mon Sep 17 00:00:00 2001 From: saood06 Date: Fri, 9 May 2025 02:09:59 -0500 Subject: [PATCH 124/132] Support for Llama-3-Nemotron models (#377) * conflict resolution * Changes to make work and add longrope support * Changes to n_attention_wv rule * Untested support of 253B * DeciLMCausalModel now reads rope_theta from config.json properly * Remove errant Granite mentions * Better n_attention_vw rule * Update vocab.py --------- Co-authored-by: Yee Man Chan Co-authored-by: Iwan Kawrakow --- convert_hf_to_gguf.py | 178 ++++++++++++++++++++ gguf-py/gguf/constants.py | 26 +++ gguf-py/gguf/tensor_mapping.py | 3 +- gguf-py/gguf/vocab.py | 33 +++- include/llama.h | 3 +- src/llama.cpp | 288 ++++++++++++++++++++++++++++++++- 6 files changed, 523 insertions(+), 8 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 20d27a5c..16f97ab0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1597,6 +1597,184 @@ class LlamaModel(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("DeciLMForCausalLM") +class DeciModel(Model): + model_arch = gguf.MODEL_ARCH.DECI + + @staticmethod + def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: + # DeciLM-specific code + intermediate_size = int(2 * ffn_mult * n_embd / 3) + return DeciModel._find_multiple(intermediate_size, 256) + + @staticmethod + def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + _block_configs: list[dict[str,Any]] = self.hparams["block_configs"] + assert self.block_count == len(_block_configs) + self._num_kv_heads = list() + self._num_heads = list() + _ffn_multipliers = list() + # ***linear attention layer*** + # if n_heads_in_group is None and replace_with_linear is True + # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads + # ***attention-free layer*** + # if n_heads_in_group is None and replace_with_linear is False + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 + # ***normal attention-layer*** + # if n_heads_in_group is not None, then + # _num_kv_heads[il] is num_attention_head // n_heads_in_group and + # _num_heads[il] is num_attention_head + # ***dummy layer*** for nemotron 253B + # if n_heads_in_group is None and ffn_mult is None + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 and _ffn_dims is 0 + for il in range(len(_block_configs)): + if _block_configs[il]["attention"]["n_heads_in_group"] is None: + if _block_configs[il]["attention"]["replace_with_linear"] is True: + self._num_kv_heads.append(0) + self._num_heads.append(self.hparams["num_attention_heads"]) + else: + self._num_kv_heads.append(0) + self._num_heads.append(0) + else: + self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) + self._num_heads.append(self.hparams["num_attention_heads"]) + if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer + _ffn_multipliers.append(0.0) + else: + _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(_ffn_multipliers) + assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) + assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int) + assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float) + self._ffn_dims: list[int] = [ + DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"]) + for multiplier in _ffn_multipliers + ] + + def set_vocab(self): + # Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's + # eos_token from '|eot_id|' to '|end_of_text|' + if self.hparams.get("vocab_size", 128256) == 128256: + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + else: + # DeciLM-7B + self._set_vocab_llama_hf() + + def set_gguf_parameters(self): + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(self._ffn_dims) + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + self.gguf_writer.add_head_count(self._num_heads) + self.gguf_writer.add_feed_forward_length(self._ffn_dims) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_file_type(self.ftype) + else: # DeciLM-7B + super().set_gguf_parameters() + if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B + self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"] + assert self.block_count == len(self._num_kv_heads) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + if bid is not None: + if "num_key_value_heads_per_layer" in self.hparams: + n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid] + elif "block_configs" in self.hparams: + n_kv_head = self._num_kv_heads[bid] + n_head = self._num_heads[bid] + else: + n_kv_head = self.hparams.get("num_key_value_heads") + else: + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + def prepare_tensors(self): + super().prepare_tensors() + + @Model.register("BitnetForCausalLM") @Model.register("BitNetForCausalLM") class BitnetModel(Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d6feb9ea..e2f4eb1a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -181,6 +181,7 @@ class GGUFType: class MODEL_ARCH(IntEnum): LLAMA = auto() + DECI = auto() FALCON = auto() BAICHUAN = auto() GROK = auto() @@ -316,6 +317,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.DECI: "deci", MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.BAICHUAN: "baichuan", MODEL_ARCH.GROK: "grok", @@ -470,6 +472,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.GROK: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1149,6 +1171,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.BAICHUAN: [ MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 1dea6a82..3ff70cd7 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -176,7 +176,8 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 + "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index dc574991..cca09798 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -122,8 +122,30 @@ class SpecialVocab: tokenizer = json.load(f) if self.load_merges: merges = tokenizer.get('model', {}).get('merges') - if isinstance(merges, list) and merges and isinstance(merges[0], str): - self.merges = merges + if isinstance(merges, list) and merges: + if isinstance(merges[0], str): + self.merges = merges + elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): + # New format since transformers 4.45 to support spaces in merges + # ref: https://github.com/ggml-org/llama.cpp/issues/9692 + # TODO: internally store as the new format instead of converting to old + if any(' ' in s for pair in merges for s in pair): + logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}') + self.merges = [ + ' '.join( + [ + # ensure the spaces are properly encoded + ''.join( + chr(ord(c) + 256) if c == ' ' else c + for c in part + ) + for part in pair + ] + ) + for pair in merges + ] + else: + raise ValueError("Unknown tokenizer merges format") added_tokens = tokenizer.get('added_tokens', {}) else: added_tokens = {} @@ -132,7 +154,12 @@ class SpecialVocab: return True with open(tokenizer_config_file, encoding = 'utf-8') as f: tokenizer_config = json.load(f) - chat_template = tokenizer_config.get('chat_template') + chat_template_alt = None + chat_template_file = path / 'chat_template.json' + if chat_template_file.is_file(): + with open(chat_template_file, encoding = 'utf-8') as f: + chat_template_alt = json.load(f).get('chat_template') + chat_template = tokenizer_config.get('chat_template', chat_template_alt) if chat_template is None or isinstance(chat_template, (str, list)): self.chat_template = chat_template else: diff --git a/include/llama.h b/include/llama.h index d7376d7d..e2901861 100644 --- a/include/llama.h +++ b/include/llama.h @@ -230,7 +230,8 @@ extern "C" { LLAMA_ROPE_SCALING_TYPE_NONE = 0, LLAMA_ROPE_SCALING_TYPE_LINEAR = 1, LLAMA_ROPE_SCALING_TYPE_YARN = 2, - LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN, + LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3, + LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE, }; enum llama_pooling_type { diff --git a/src/llama.cpp b/src/llama.cpp index dd9d82dc..9fbb58d1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -185,6 +185,7 @@ static std::string format(const char * fmt, ...) { enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_LLAMA4, + LLM_ARCH_DECI, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, @@ -240,6 +241,7 @@ enum llm_arch { static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_LLAMA4, "llama4" }, + { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GPT2, "gpt2" }, @@ -282,7 +284,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_BITNET_25, "bitnet-25" }, - { LLM_ARCH_BITNET_B158, "bitnet-b1.58" }, + { LLM_ARCH_BITNET_B158, "bitnet-b1.58" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, @@ -632,6 +634,32 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_DECI, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_LLAMA4, { @@ -2525,6 +2553,7 @@ enum e_model { MODEL_70B, MODEL_236B, MODEL_314B, + MODEL_405B, MODEL_671B, MODEL_SMALL, MODEL_MEDIUM, @@ -5128,6 +5157,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_70B: return "70B"; case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; + case MODEL_405B: return "405B"; case MODEL_671B: return "671B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; @@ -5265,7 +5295,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON || model.arch == LLM_ARCH_BITNET_25 || model.arch == LLM_ARCH_BITNET_B158 || model.arch == LLM_ARCH_DECI) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -5320,6 +5350,16 @@ static void llm_load_hparams( if (model.type == MODEL_17B_128E) { hparams.use_kq_norm = false; + } + } break; + case LLM_ARCH_DECI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + case 162: model.type = e_model::MODEL_405B; break; + default: model.type = e_model::MODEL_UNKNOWN; } } break; case LLM_ARCH_MINICPM: @@ -6956,6 +6996,76 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_DECI: + { + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}); + layer.wk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}); + layer.wv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}); + layer.wo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}); + } + + // optional bias tensors + + + layer.bq = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (n_ff > 0) { + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + } + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_rot/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_rot/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + + if (n_ff > 0) { + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + + // optional MLP bias + layer.ffn_gate_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_LLAMA4: { model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -10485,6 +10595,168 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_deci() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head = hparams.n_head(il); + const int64_t n_ff = hparams.n_ff(il); + + if (n_head == 0) { // attention-free layer of Llama-3_1-Nemotron-51B + cur = inpL; + } else { + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + } + + if (n_head > 0 && n_head_kv == 0) { // "linear attention" of Llama-3_1-Nemotron-51B + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "wo", il); + } else if (n_head > 0) { + // self-attention + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B + if (n_ff == 0) { + continue; + } + + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + // modified to support attention-free layer of Llama-3_1-Nemotron-51B + struct ggml_tensor * ffn_inp = cur; + if (n_head > 0) { + ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + if (hparams.f_logit_scale) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_baichuan() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -16477,6 +16749,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_llama(); } break; + case LLM_ARCH_DECI: + { + result = llm.build_deci(); + } break; case LLM_ARCH_BAICHUAN: { result = llm.build_baichuan(); @@ -18997,8 +19273,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // - qs.n_attention_wv == 0 for Mamba models // - qs.n_attention_wv == model.hparams.n_layer for Transformer models // - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models + // - model.arch == LLM_ARCH_DECI for Deci-Nemotron models // - GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected"); + GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer || model.arch == LLM_ARCH_DECI) && "n_attention_wv is unexpected"); size_t total_size_org = 0; size_t total_size_new = 0; @@ -20298,6 +20575,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_DECI: case LLM_ARCH_LLAMA4: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: @@ -20371,6 +20649,10 @@ int32_t llama_n_layer(const struct llama_model * model) { return model->hparams.n_layer; } +int32_t llama_n_head(const struct llama_model * model) { + return model->hparams.n_head(); +} + float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } From 496451a1d4c41300ebdb102f12401b8ffa5b1d4b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 9 May 2025 10:13:25 +0300 Subject: [PATCH 125/132] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 46638dd7..ea4144bb 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) ## Latest News +* May 9 2025: Support for LlaMA-3-Nmotron models added, see [PR 377](https://github.com/ikawrakow/ik_llama.cpp/pull/377) * May 7 2025: 🚀 Faster TG for DeepSeek models with GPU or hybrid GPU/CPU inference. See [PR 386](https://github.com/ikawrakow/ik_llama.cpp/pull/386) for details. Caveat: Ampere or newer Nvidia GPU required * May 4 2025: 🚀 Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see [PR #370](https://github.com/ikawrakow/ik_llama.cpp/pull/370) * April 29 2025: Qwen3 support added From 8777fc4855dd1551c20a84cb266f75fa49e9b0e8 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 9 May 2025 10:22:48 +0300 Subject: [PATCH 126/132] Fix CUDA FlashMLA-3 with quantized KV cache (#400) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/fattn-new-mma.cu | 48 ++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 796d9c7b..d1484451 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1362,26 +1362,46 @@ void launch_fattn_new_mma( to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); K_data = (char *) K_f16.ptr; - const size_t bs = ggml_blck_size(K->type); - const size_t ts = ggml_type_size(K->type); + nb11 = K->ne[0]*sizeof(half); + nb12 = nb11*K->ne[1]; + nb13 = nb12*K->ne[2]; - nb11 = nb11*bs*sizeof(half)/ts; - nb12 = nb12*bs*sizeof(half)/ts; - nb13 = nb13*bs*sizeof(half)/ts; + // Original PR in llama.cpp. I don't think that can work when K is not contiguous (e.g., nb11 > nb12), there are + // gaps between the rows, etc., as ggml_get_to_fp16_cuda stores into contiguous memory. + //const size_t bs = ggml_blck_size(K->type); + //const size_t ts = ggml_type_size(K->type); + + //nb11 = nb11*bs*sizeof(half)/ts; + //nb12 = nb12*bs*sizeof(half)/ts; + //nb13 = nb13*bs*sizeof(half)/ts; } if (need_f16_V && V->type != GGML_TYPE_F16) { - V_f16.alloc(ggml_nelements(V)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; + if constexpr (DV == 512) { + // DeepSeek. In this case the V cache is the same as the K cache, except that + // it has 512 elements per row instead of 576. + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; + V_data = K_data; + } else { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); + nb21 = K->ne[0]*sizeof(half); + nb22 = nb21*V->ne[1]; + nb23 = nb22*V->ne[2]; - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + // Original PR in llama.cpp. Same comment as above for the K cache. + //const size_t bs = ggml_blck_size(V->type); + //const size_t ts = ggml_type_size(V->type); + + //nb21 = nb21*bs*sizeof(half)/ts; + //nb22 = nb22*bs*sizeof(half)/ts; + //nb23 = nb23*bs*sizeof(half)/ts; + } } int parallel_blocks = 1; From e5a4a3ce78ce96b6822dcd6138a98c4d237ecc9b Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 9 May 2025 11:16:36 +0300 Subject: [PATCH 127/132] Update README.md @saood06 Thanks! --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ea4144bb..c1381cad 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This repository is a fork of [llama.cpp](https://github.com/ggerganov/llama.cpp) ## Latest News -* May 9 2025: Support for LlaMA-3-Nmotron models added, see [PR 377](https://github.com/ikawrakow/ik_llama.cpp/pull/377) +* May 9 2025: Support for LlaMA-3-Nemotron models added, see [PR 377](https://github.com/ikawrakow/ik_llama.cpp/pull/377) * May 7 2025: 🚀 Faster TG for DeepSeek models with GPU or hybrid GPU/CPU inference. See [PR 386](https://github.com/ikawrakow/ik_llama.cpp/pull/386) for details. Caveat: Ampere or newer Nvidia GPU required * May 4 2025: 🚀 Significant token generation performance improvement on CUDA with Flash Attention for GQA models. For details and benchmarks see [PR #370](https://github.com/ikawrakow/ik_llama.cpp/pull/370) * April 29 2025: Qwen3 support added From 967a2e1860482397a6a386952d972ac1205474ad Mon Sep 17 00:00:00 2001 From: saood06 Date: Fri, 9 May 2025 09:17:41 -0500 Subject: [PATCH 128/132] Fix missing rope_freqs with convert_hf_to_gguf (#402) * lora : fix llama conversion script with ROPE_FREQS * convert : refactor rope_freqs generation This should also fix vocab-only conversion for Phi-3. * convert : adapt MiniCPM3 to separate rope_freqs insertion MiniCPM3's tokenizer is treated as a SentencePiece tokenizer to avoid having to run its custom Python code which mixes tokenization in the same file as tool calls. gguf-py : add long and short RoPE factors to tensor mappings Empty, but the key names are used to populate the mappings. --------- Co-authored-by: Xuan Son Nguyen Co-authored-by: Francis Couture-Harpin --- convert_hf_to_gguf.py | 23 ++++++++++++++++++----- convert_lora_to_gguf.py | 4 ++++ gguf-py/gguf/constants.py | 2 ++ gguf-py/gguf/tensor_mapping.py | 3 +++ 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 16f97ab0..966cfcd3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -14,6 +14,7 @@ from enum import IntEnum from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from itertools import chain import math import numpy as np @@ -256,10 +257,14 @@ class Model: return False + # some models need extra generated tensors (like rope_freqs) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + return () + def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - for name, data_torch in self.get_tensors(): + for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -1559,7 +1564,7 @@ class LlamaModel(Model): return [(self.map_tensor_name(name), data_torch)] - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -1586,8 +1591,9 @@ class LlamaModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + def prepare_tensors(self): super().prepare_tensors() if self._experts is not None: @@ -2307,6 +2313,13 @@ class Phi3MiniModel(Model): self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"])) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head + # write rope scaling for long context (128k) model rope_scaling = self.find_hparam(['rope_scaling'], True) if rope_scaling is None: @@ -2336,8 +2349,8 @@ class Phi3MiniModel(Model): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) @Model.register("PlamoForCausalLM") diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index a88d0d4a..ef088034 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -331,6 +331,10 @@ if __name__ == '__main__': self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) super().set_gguf_parameters() + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Never add extra tensors (e.g. rope_freqs) for LoRA adapters + return () + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_map: dict[str, PartialLoraTensor] = {} diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index e2f4eb1a..6819979f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -806,6 +806,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 3ff70cd7..9688b02c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -82,6 +82,9 @@ class TensorNameMap: "rope.freqs", # llama-pth "rotary_pos_emb.inv_freq", # chatglm ), + + MODEL_TENSOR.ROPE_FACTORS_LONG: (), + MODEL_TENSOR.ROPE_FACTORS_SHORT: (), } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { From 43a154d8b8b0e9217114577442cecb224a488d45 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 9 May 2025 22:00:40 +0300 Subject: [PATCH 129/132] Handle incompatible DeepSeek GGUFs (#394) Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 63 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 9fbb58d1..d0f76c49 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3468,6 +3468,30 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } + if (model.arch == LLM_ARCH_DEEPSEEK2) { + bool have_wkv_b = true; + for (auto& l : model.layers) { + if (!l.wkv_b) { + have_wkv_b = false; + break; + } + } + if (!have_wkv_b) { + if (cparams.mla_attn != 1) { + LLAMA_LOG_WARN("=========================================================\n"); + LLAMA_LOG_WARN("%s: missing wkv_b tensor(s)\n", __func__); + LLAMA_LOG_WARN("%s: changing MLA from %d to 1\n", __func__, cparams.mla_attn); + if (cparams.mla_attn > 1) { + LLAMA_LOG_WARN("%s: ** Prompt processing performance will be crippled **\n", __func__); + } + LLAMA_LOG_WARN("=========================================================\n"); + // Sorry for the hack. + auto& non_cparams = const_cast(cparams); + non_cparams.mla_attn = 1; + } + } + } + if (model.arch == LLM_ARCH_DEEPSEEK2 && cparams.mla_attn) { // DeepSeek MLA cache.kv_l.reserve(n_layer); @@ -3497,7 +3521,7 @@ static bool llama_kv_cache_init( const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; - LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); + //LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank); if (cparams.flash_attn) { ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size); ggml_format_name(kv, "cache_kv_l%d", i); @@ -5847,6 +5871,25 @@ static void llm_load_hparams( } break; case LLM_ARCH_DEEPSEEK2: { + if (hparams.n_head_kv() == 1) { + printf("==========================================================================\n"); + printf("Detected incompatible DeepSeek model.\n"); + printf("Will try to fix, but there are no guarantees\n\n"); + printf("*** Your prompt processing speed will be crippled ***\n\n"); + printf("Consider making your own ik_llama.cpp compatible model or\n"); + printf("ask the model provider to make one for you,\n"); + int n_nead_kv = hparams.n_gqa(); + if (n_nead_kv%16 != 0 || hparams.n_embd_head_k != 576 || hparams.n_embd_head_v != 512 || + hparams.n_rot != 64) { + printf("Sorry, uknown model => cannot fix it => bailing out\n"); + GGML_ABORT("Fatal error"); + } + for (auto& item : hparams.n_head_kv_arr) item = n_nead_kv; + hparams.n_embd_head_k = 192; + hparams.n_embd_head_v = 128; + printf("==========================================================================\n"); + //GGML_ABORT("Fatal error"); + } bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); @@ -5859,7 +5902,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == 0) { + if (hparams.expert_gating_func == 0) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; @@ -8419,10 +8462,18 @@ static bool llm_load_tensors( layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); } - layer.wkv_a_mqa = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); - layer.wkv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}); - layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); - layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); + layer.wkv_a_mqa = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i),{n_embd, kv_lora_rank + (n_embd_head_qk_rope)}); + layer.wkv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, llama_model_loader::TENSOR_NOT_REQUIRED); + if (!layer.wkv_b) { + // Incompatible mainline model. Let's see if we can still load it + layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v, n_head}, 0); + + } else { + layer.wk_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 1); + layer.wv_b = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 1); + } layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}); layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); From a2d24c97e5c5c28aeb3669dcc0044b69258a85ca Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sat, 10 May 2025 18:52:54 +0300 Subject: [PATCH 130/132] TG improvements for MoE models (#404) * cuda: Remove unnecessary device to host copy of row ids We get 3-4% TG speed improvement for DeepSeek-Lite just from that. * CPU: fix get_rows when SER is used With smart experts reduction (SER), one potentially uses fewer experts than specified by the model. This is accomplished by setting the ID of the not seected tensors to -1. Most of the necessary stuff was implemented when I added the SER option, but I forgot to update get_rows() for not quantized tensors. As a result, we get random garbage for the weights of the not-selected epxerts, which leads to garbage output. This commit fixes it on the CPU. I'm not quite sure yet why the GPU is not working. * CUDA: fix TG with SER --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 5 ----- ggml/src/ggml-cuda/mmvq.cu | 19 ++++++++++++++++++- ggml/src/ggml.c | 35 +++++++++++++++++++++-------------- 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ff6e064c..87f80d0c 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2505,11 +2505,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dst_padded_col_size, next->src[0]->type, stream); CUDA_CHECK(cudaGetLastError()); - std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); - CUDA_CHECK(cudaStreamSynchronize(stream)); - local_dst.ne[2] = 1; auto local_next = *next; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index f87ebb96..bc26cce4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -147,10 +147,27 @@ static __global__ void mul_mat_vec_q( const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { int i2 = blockIdx.y; + char * cdst = (char *)dst + i2*nb2; int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; + if (i02 < 0) { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int rows_per_cuda_block = 1; +#else + constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) + const int row0 = rows_per_cuda_block*blockIdx.x; + if (threadIdx.y == 0) { + dst = (float *)cdst; + for (int j = 0; j < ncols_y; ++j) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + dst[j*nrows_dst + row0 + threadIdx.x] = 0; + } + } + } + return; + } const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; - char * cdst = (char *)dst + i2*nb2; mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4cd18a28..d82466e0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -15911,11 +15911,14 @@ static void ggml_compute_forward_get_rows_f16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); + if (i01 >= 0 && i01 < ne01) { + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } else { + memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float)); + } - ggml_fp16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } } @@ -15952,11 +15955,13 @@ static void ggml_compute_forward_get_rows_bf16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); - - ggml_bf16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + if (i01 >= 0 && i01 < ne01) { + ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } else { + memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float)); + } } } @@ -15993,11 +15998,13 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - assert(i01 >= 0 && i01 < ne01); - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), - (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + if (i01 >= 0 && i01 < ne01) { + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } else { + memset((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float)); + } } } From 36e6e888b75ae93fb5aac212bb0e147d8379ae23 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Sun, 11 May 2025 08:12:47 +0300 Subject: [PATCH 131/132] Fix race in the CUDA DeepSeek FA kernel (#406) Reference: https://github.com/ggml-org/llama.cpp/pull/13438 Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/fattn-new-mma.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index d1484451..8da96370 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -898,6 +898,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { From 504fb890d90ec27e5f4822b7bd772fa94d4d6aac Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 11 May 2025 12:22:19 +0300 Subject: [PATCH 132/132] Revert "Fix race in the CUDA DeepSeek FA kernel (#406)" This reverts commit 36e6e888b75ae93fb5aac212bb0e147d8379ae23. I should have tested. We get NaNs. --- ggml/src/ggml-cuda/fattn-new-mma.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 8da96370..d1484451 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -898,8 +898,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } - __syncthreads(); - // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) {