diff --git a/common/common.cpp b/common/common.cpp index 75dd78e6..95e91bc1 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -906,6 +906,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mmap = false; return true; } + if (arg == "-rtr" || arg == "--run-time-repack") { + params.repack_tensors = true; + params.use_mmap = false; + return true; + } if (arg == "--numa") { CHECK_ARG std::string value(argv[i]); @@ -1579,6 +1584,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param if (llama_supports_mmap()) { options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" }); } + options.push_back({ "*", " --run-time-repack", "repack tensors if interleaved variant is available"}); options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n" " - distribute: spread execution evenly over all nodes\n" " - isolate: only spawn threads on CPUs on the node that execution started on\n" @@ -2204,6 +2210,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & mparams.use_mmap = params.use_mmap; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; + mparams.repack_tensors = params.repack_tensors; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; } else { @@ -3244,6 +3251,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "repack: %s # default: false\n", params.repack_tensors ? "true" : "false"); fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? 
"true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); diff --git a/common/common.h b/common/common.h index 486017ef..73d7d650 100644 --- a/common/common.h +++ b/common/common.h @@ -187,6 +187,7 @@ struct gpt_params { bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data + bool repack_tensors = false; // repack tensors if interleaved variant is available std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b470a088..3910aa1d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -590,6 +590,9 @@ class Model: if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249": # ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M res = "smollm" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 + res = "deepseek-v3" if res is None: logger.warning("\n") diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index d5a2d925..40af02f4 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -94,6 +94,7 @@ models = [ {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, ] diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25e7c775..55f825fe 
100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -139,6 +139,8 @@ int main(int argc, char ** argv) { const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg); if (n_ctx_req > n_kv_max) { + printf("n_ctx_req = %d is greater than n_kv_max = %d for pp = %d, tg = %d, pl = %d\n", + n_ctx_req, n_kv_max, pp, tg, pl); continue; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 9e4fd266..75fe40d1 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -238,6 +238,7 @@ struct cmd_params { int reps; bool verbose; bool warmup; + bool repack = false; output_formats output_format; output_formats output_format_stderr; }; @@ -265,6 +266,7 @@ static const cmd_params cmd_params_defaults = { /* reps */ 5, /* verbose */ false, /* warmup */ true, + /* repack */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -298,6 +300,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -oe, --output-err (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr)); printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); + printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? 
"1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -571,6 +574,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.warmup = std::stoi(argv[i]); + } else if (arg == "-rtr" || arg == "--run-time-repack") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.repack = std::stoi(argv[i]); } else { invalid_param = true; break; @@ -623,6 +632,7 @@ struct cmd_params_instance { std::vector tensor_split; bool use_mmap; bool embeddings; + bool repack = false; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -635,6 +645,7 @@ struct cmd_params_instance { mparams.main_gpu = main_gpu; mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; + mparams.repack_tensors = repack; return mparams; } @@ -646,6 +657,7 @@ struct cmd_params_instance { split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && + repack == other.repack && tensor_split == other.tensor_split; } @@ -706,6 +718,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .repack = */ params.repack, }; instances.push_back(instance); } @@ -732,6 +745,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .repack = */ params.repack, }; instances.push_back(instance); } @@ -758,6 +772,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, + /* .repack = */ params.repack, }; instances.push_back(instance); } @@ -796,6 +811,7 @@ struct test { std::vector tensor_split; bool use_mmap; bool embeddings; + bool repack = false; int n_prompt; int n_gen; std::string test_time; @@ -822,6 +838,7 @@ struct test 
{ tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; + repack = inst.repack; n_prompt = inst.n_prompt; n_gen = inst.n_gen; // RFC 3339 date-time format @@ -891,7 +908,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", - "tensor_split", "use_mmap", "embeddings", + "tensor_split", "use_mmap", "embeddings", "repack", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts" @@ -912,7 +929,7 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -947,7 +964,7 @@ struct test { std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()) @@ -1112,6 +1129,9 @@ struct markdown_printer : public printer { if (field == "use_mmap") { return 4; } + if (field == "repack") { + return 3; + } if (field == "test") { return 13; } @@ -1143,6 +1163,9 @@ struct markdown_printer : public printer { if (field == "use_mmap") { return "mmap"; } + if (field == "repack") { + return "rtr"; + } if (field == "embeddings") { return "embd"; } @@ -1198,6 +1221,9 @@ struct markdown_printer : 
public printer { if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) { fields.emplace_back("embeddings"); } + if (params.repack != cmd_params_defaults.repack) { + fields.emplace_back("repack"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 0c6daa12..a7d80326 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -22,50 +22,74 @@ static const std::vector QUANT_OPTIONS = { { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", }, { "Q6_0", LLAMA_FTYPE_MOSTLY_Q6_0, " 6.5 bpw quantization", }, { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", }, + { "IQ2_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4,"IQ2_XXS repacked", }, { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", }, + { "IQ2_XS_R4",LLAMA_FTYPE_MOSTLY_IQ2_XS_R4,"IQ2_XS repacked", }, { "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S, " 2.5 bpw quantization", }, { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, + { "IQ2_M_R4", LLAMA_FTYPE_MOSTLY_IQ2_M_R4, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "IQ1_BN", LLAMA_FTYPE_MOSTLY_IQ1_BN, " 1.62 bpw quantization (Bitnet)", }, { "IQ2_BN", LLAMA_FTYPE_MOSTLY_IQ2_BN, " 2.00 bpw quantization (Bitnet)", }, + { "IQ2_BN_R4",LLAMA_FTYPE_MOSTLY_IQ2_BN_R4," 2.00 bpw quantization (Bitnet)", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, + { "Q2_K_R4", LLAMA_FTYPE_MOSTLY_Q2_K_R4, "Q2_K_S repacked", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, { "IQ3_KT", LLAMA_FTYPE_MOSTLY_IQ3_KT, " 3.125 bpw trellis quantization", }, { "IQ4_KT", LLAMA_FTYPE_MOSTLY_IQ4_KT, " 4.0 bpw trellis quantization", }, + { 
"IQ3_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4,"IQ3_XXS repacked", }, { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", }, + { "IQ3_S_R4", LLAMA_FTYPE_MOSTLY_IQ3_S_R4, "IQ3_S repacked", }, { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M, " 3.66 bpw quantization mix", }, { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" }, + { "Q3_K_R4", LLAMA_FTYPE_MOSTLY_Q3_K_R4, "Q3_K_S repacked" }, { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS, " 3.3 bpw quantization" , }, { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", }, { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M, " 3.07G, +0.2496 ppl @ LLaMA-v1-7B", }, { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L, " 3.35G, +0.1764 ppl @ LLaMA-v1-7B", }, { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", }, + { "IQ4_NL_R4",LLAMA_FTYPE_MOSTLY_IQ4_NL_R4," 4.50 bpw non-linear quantization", }, + { "IQ4_XS_R4",LLAMA_FTYPE_MOSTLY_IQ4_XS_R4," 4.25 bpw non-linear quantization", }, + { "Q4_0_R4", LLAMA_FTYPE_MOSTLY_Q4_0_R4, " 4.50 bpw quantization", }, + { "Q5_0_R4", LLAMA_FTYPE_MOSTLY_Q5_0_R4, " 5.50 bpw quantization", }, + { "Q6_0_R4", LLAMA_FTYPE_MOSTLY_Q6_0_R4, " 6.50 bpw quantization", }, + { "Q8_0_R4", LLAMA_FTYPE_MOSTLY_Q8_0_R4, " 8.50 bpw quantization", }, { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", }, { "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", }, + { "IQ4_KS_R4",LLAMA_FTYPE_MOSTLY_IQ4_KS_R4,"IQ4_KS repacked", }, { "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", }, { "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",}, + { "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",}, { "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",}, { "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", }, { "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", }, + { "IQ3_K_R4", LLAMA_FTYPE_MOSTLY_IQ3_K_R4, "IQ3_K repacked", }, { "IQ3_KL", 
LLAMA_FTYPE_MOSTLY_IQ3_KL, " 4 bpw non-linear quantization mix",}, { "IQ4_K", LLAMA_FTYPE_MOSTLY_IQ4_K, " 4.5 bpw non-linear quantization", }, + { "IQ4_K_R4", LLAMA_FTYPE_MOSTLY_IQ4_K_R4, "IQ4_K repacked", }, { "IQ5_K", LLAMA_FTYPE_MOSTLY_IQ5_K, " 5.5 bpw non-linear quantization", }, + { "IQ5_K_R4", LLAMA_FTYPE_MOSTLY_IQ5_K_R4, "IQ5_K repacked", }, { "IQ6_K", LLAMA_FTYPE_MOSTLY_IQ6_K, " 6.6 bpw non-linear quantization", }, { "Q4_K", LLAMA_FTYPE_MOSTLY_Q4_K_M, "alias for Q4_K_M", }, + { "Q4_K_R4", LLAMA_FTYPE_MOSTLY_Q4_K_R4, "Q4_K_S repacked", }, { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S, " 3.59G, +0.0992 ppl @ LLaMA-v1-7B", }, { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M, " 3.80G, +0.0532 ppl @ LLaMA-v1-7B", }, { "Q5_K", LLAMA_FTYPE_MOSTLY_Q5_K_M, "alias for Q5_K_M", }, + { "Q5_K_R4", LLAMA_FTYPE_MOSTLY_Q5_K_R4, "Q5_K_S repacked", }, { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", }, { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", }, + { "Q6_K_R4", LLAMA_FTYPE_MOSTLY_Q6_K_R4, "Q6_K repacked", }, + { "Q8_K_R8", LLAMA_FTYPE_MOSTLY_Q8_K_R8, "Q8_K repacked", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", }, { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, -0.0020 ppl @ Mistral-7B", }, { "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", }, + { "BF16_R16", LLAMA_FTYPE_MOSTLY_BF16_R16, "14.00G, -0.0050 ppl @ Mistral-7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching. 
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", }, @@ -486,8 +510,8 @@ int main(int argc, char ** argv) { if (!params.ignore_imatrix_rules && imatrix_data.empty() && (params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || - params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || - params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || + params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || + params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M)) { fprintf(stderr, "\n==========================================================================================================\n"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 380e7dfd..c88a5a1f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -392,6 +392,12 @@ extern "C" { GGML_TYPE_Q4_0_4_8 = 32, GGML_TYPE_Q4_0_8_8 = 33, // + // So we are able to consume MS BitNet I2_S quants + // + GGML_TYPE_I2_S = 36, + // + GGML_TYPE_Q8_0_X4 = 98, + GGML_TYPE_Q8_1_X4 = 99, GGML_TYPE_Q6_0 = 133, GGML_TYPE_IQ1_BN = 134, GGML_TYPE_IQ2_BN = 135, @@ -406,9 +412,37 @@ extern "C" { GGML_TYPE_IQ4_KS = 144, GGML_TYPE_IQ2_KS = 145, GGML_TYPE_IQ4_KSS = 146, - GGML_TYPE_IQ2_KT = 147, - GGML_TYPE_IQ3_KT = 148, - GGML_TYPE_IQ4_KT = 149, + GGML_TYPE_Q8_K16 = 147, + GGML_TYPE_Q8_K32 = 148, + GGML_TYPE_Q8_KR8 = 149, + GGML_TYPE_IQ2_KT = 150, + GGML_TYPE_IQ3_KT = 151, + GGML_TYPE_IQ4_KT = 152, + + GGML_TYPE_Q4_0_R4 = 202, + GGML_TYPE_Q5_0_R4 = 206, + GGML_TYPE_Q8_0_R4 = 208, + GGML_TYPE_Q2_K_R4 = 210, + GGML_TYPE_Q3_K_R4 = 211, + GGML_TYPE_Q4_K_R4 = 212, + GGML_TYPE_Q5_K_R4 = 213, + GGML_TYPE_Q6_K_R4 = 214, + GGML_TYPE_IQ2_XXS_R4= 216, + GGML_TYPE_IQ2_XS_R4 = 217, + GGML_TYPE_IQ3_XXS_R4= 218, + GGML_TYPE_IQ4_NL_R4 = 220, + GGML_TYPE_IQ3_S_R4 = 221, + GGML_TYPE_IQ2_S_R4 = 222, + GGML_TYPE_IQ4_XS_R4 = 223, + GGML_TYPE_BF16_R16 = 
230, + GGML_TYPE_Q6_0_R4 = 233, + GGML_TYPE_IQ2_BN_R4 = 335, + GGML_TYPE_IQ2_K_R4 = 337, + GGML_TYPE_IQ3_K_R4 = 338, + GGML_TYPE_IQ4_K_R4 = 339, + GGML_TYPE_IQ5_K_R4 = 340, + GGML_TYPE_IQ4_KS_R4 = 344, + GGML_TYPE_Q8_K_R8 = 399, GGML_TYPE_COUNT, }; @@ -470,6 +504,31 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_KT = 140, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_KT = 141, // except 1d tensors GGML_FTYPE_MOSTLY_IQ4_KT = 142, // except 1d tensors + // + GGML_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_K_R4 = 210, // except 1d tensors + GGML_FTYPE_MOSTLY_Q3_K_R4 = 211, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_K_R4 = 212, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_K_R4 = 213, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K_R4 = 214, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_XXS_R4= 215, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_XS_R4 = 216, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_XXS_R4= 217, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_NL_R4 = 219, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_S_R4 = 220, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_S_R4 = 221, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_XS_R4 = 222, // except 1d tensors + GGML_FTYPE_MOSTLY_BF16_R16 = 224, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_0_R4 = 227, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_BN_R4 = 329, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ2_K_R4 = 330, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ3_K_R4 = 331, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_K_R4 = 332, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ5_K_R4 = 333, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_KS_R4 = 337, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 54999d2f..27308140 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h 
@@ -182,6 +182,13 @@ typedef struct { } block_q5_0; static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); +typedef struct { + ggml_half d[4]; // delta + uint8_t qh[QK5_0/2]; // 5-th bit of quants + uint8_t qs[QK5_0*2]; // nibbles / quants +} block_q5_0_r4; +static_assert(sizeof(block_q5_0_r4) == 4*sizeof(ggml_half) + QK5_0*2 + QK5_0/2, "wrong q5_0_r4 block size/padding"); + #define QK5_1 32 typedef struct { GGML_SCALE_TYPE1(m, dm); @@ -198,6 +205,13 @@ typedef struct { } block_q6_0; static_assert(sizeof(block_q6_0) == sizeof(ggml_half) + QK6_0/2 + QK6_0/4, "wrong q6_0 block size/padding"); +typedef struct { + ggml_half d[4]; // delta + uint8_t qh[QK6_0]; // 5+6-th bit of quants + uint8_t qs[QK6_0*2]; // nibbles / quants +} block_q6_0_r4; +static_assert(sizeof(block_q6_0_r4) == 4*sizeof(ggml_half) + QK6_0*2 + QK6_0, "wrong q6_0_r4 block size/padding"); + #define QK8_0 32 typedef struct { ggml_half d; // delta @@ -262,6 +276,13 @@ typedef struct { } block_q2_K; static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); +typedef struct { + ggml_half d[8]; + uint8_t scales[QK_K/4]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K]; // quants +} block_q2_k_r4; +static_assert(sizeof(block_q2_k_r4) == 8*sizeof(ggml_half) + QK_K/4 + QK_K, "wrong q2_k_r4 block size/padding"); + // 3-bit quantization // weight is represented as x = a * q // 16 blocks of 16 elements each @@ -274,6 +295,15 @@ typedef struct { } block_q3_K; static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +typedef struct { + ggml_half d[4]; // super-block scales + uint8_t scales_h[QK_K/16]; // scales quantized with 6 bits (high 2 bits) + uint8_t scales_l[QK_K/8]; // scales quantized with 6 bits (low 4 bits) + uint8_t qh[QK_K/2]; // quants - high bit + uint8_t qs[QK_K]; // quants - low 2 bits +} block_q3_k_r4; 
+static_assert(sizeof(block_q3_k_r4) == 4*sizeof(ggml_half) + QK_K/16 + QK_K/8 + QK_K/2 + QK_K, "wrong q3_k_r4 block size/padding"); + // 4-bit quantization // 8 blocks of 32 elements each // weight is represented as x = a * q + b @@ -285,6 +315,14 @@ typedef struct { } block_q4_K; static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +typedef struct { + ggml_half d[8]; + uint8_t scales_h[QK_K/16];// scales and mins, quantized with 6 bits + uint8_t scales_l[QK_K/8]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K*2]; // 4--bit quants +} block_q4_k_r4; +static_assert(sizeof(block_q4_k_r4) == 8*sizeof(ggml_half) + QK_K/16 + QK_K/8 + QK_K*2, "wrong q4_k_r4 block size/padding"); + // 5-bit quantization // 8 blocks of 32 elements each // weight is represented as x = a * q + b @@ -297,6 +335,15 @@ typedef struct { } block_q5_K; static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +typedef struct { + ggml_half d[8]; + uint8_t scales_h[QK_K/16];// scales and mins, quantized with 6 bits + uint8_t scales_l[QK_K/8]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/2]; // quants, high bit + uint8_t qs[QK_K*2]; // quants, low 4 bits +} block_q5_k_r4; +static_assert(sizeof(block_q5_k_r4) == 8*sizeof(ggml_half) + QK_K/16 + QK_K/8 + QK_K/2 + QK_K*2, "wrong q5_k_r4 block size/padding"); + // 6-bit quantization // weight is represented as x = a * q // 16 blocks of 16 elements each @@ -309,6 +356,14 @@ typedef struct { } block_q6_K; static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); +typedef struct { + ggml_half d[4]; // super-block scale + int8_t scales[QK_K/4]; // scales, quantized with 8 bits + uint8_t qh[QK_K]; // quants, upper 2 bits + uint8_t ql[QK_K*2]; // quants, lower 4 bits +} block_q6_k_r4; +static_assert(sizeof(block_q6_k_r4) == 4*sizeof(ggml_half) + QK_K/4 + 
3*QK_K, "wrong q6_k_r4 block size/padding"); + // This is only used for intermediate quantization and dot products typedef struct { float d; // delta @@ -327,6 +382,12 @@ typedef struct { } block_q8_K128; static_assert(sizeof(block_q8_K128) == sizeof(float) + 128, "wrong q8_K128 block size/padding"); +typedef struct { + ggml_half d[8]; // delta + int8_t qs[8*QK_K]; // quants, stored as unsigned ints +} block_q8_k_r8; +static_assert(sizeof(block_q8_k_r8) == 8*sizeof(ggml_half) + 8*QK_K, "wrong q8_k_r8 block size/padding"); + // (Almost) "true" 2-bit quantization. // Due to the need to use blocks as per ggml design, it ends up using // 2.0625 bpw because of the 16-bit scale for each block of 256. @@ -336,6 +397,13 @@ typedef struct { } block_iq2_xxs; static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t sas[QK_K/2]; + uint8_t qs[QK_K/2]; +} block_iq2_xxs_r4; +static_assert(sizeof(block_iq2_xxs_r4) == 4*sizeof(block_iq2_xxs), "wrong iq2_xxs_r4 block size/padding"); + // 2.3125 bpw quants typedef struct { ggml_half d; @@ -344,6 +412,13 @@ typedef struct { } block_iq2_xs; static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); +typedef struct { + ggml_half d[4]; + uint16_t qs[QK_K/2]; + uint8_t scales[QK_K/8]; +} block_iq2_xs_r4; +static_assert(sizeof(block_iq2_xs_r4) == 4*sizeof(block_iq2_xs), "wrong iq2_xs_r4 block size/padding"); + // 2.5625 bpw quants typedef struct { ggml_half d; @@ -353,6 +428,15 @@ typedef struct { } block_iq2_s; static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t qs[QK_K/2]; + uint8_t qh[QK_K/8]; + uint8_t signs[QK_K/2]; + uint8_t scales[QK_K/8]; +} block_iq2_s_r4; +static_assert(sizeof(block_iq2_s_r4) == 4*sizeof(block_iq2_s), "wrong iq2_s_r4 block 
size/padding"); + // (Almost) "true" 3-bit quantization. // Due to the need to use blocks as per ggml design, it ends up using // 3.0625 bpw because of the 16-bit scale for each block of 256. @@ -362,6 +446,13 @@ typedef struct { } block_iq3_xxs; static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t sas[QK_K/2]; + uint8_t qs[QK_K]; +} block_iq3_xxs_r4; +static_assert(sizeof(block_iq3_xxs_r4) == 4*sizeof(block_iq3_xxs), "wrong iq3_xxs_r4 block size/padding"); + // 3.4375 bpw #define IQ3S_N_SCALE QK_K/64 typedef struct { @@ -373,6 +464,15 @@ typedef struct { } block_iq3_s; static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t qs[QK_K]; + uint8_t qh[QK_K/8]; + uint8_t signs[QK_K/2]; + uint8_t scales[4*IQ3S_N_SCALE]; +} block_iq3_s_r4; +static_assert(sizeof(block_iq3_s_r4) == 4*sizeof(block_iq3_s), "wrong iq3_s_r4 block size/padding"); + typedef struct { ggml_half d; uint8_t qs[QK_K/8]; @@ -419,6 +519,11 @@ typedef struct { uint8_t qs[QK4_NL/2]; } block_iq4_nl; static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t qs[2*QK4_NL]; +} block_iq4_nl_r4; +static_assert(sizeof(block_iq4_nl_r4) == 4*sizeof(ggml_half) + 2*QK4_NL, "wrong iq4_nl_x4 block size/padding"); typedef struct { ggml_half d; @@ -428,12 +533,26 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t scales_h[QK_K/32]; + uint8_t scales_l[QK_K/16]; + uint8_t qs[QK_K*2]; +} block_iq4_xs_r4; +static_assert(sizeof(block_iq4_xs_r4) == 4*sizeof(ggml_half) + QK_K/32 + QK_K/16 + QK_K*2, "wrong iq4_xs_rs block size/padding"); + typedef struct { uint8_t 
scales[QK_K/32]; uint8_t qs[QK_K/2]; } block_iq4_ks; static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding"); +typedef struct { + uint8_t scales[QK_K/8]; + uint8_t qs[QK_K*2]; +} block_iq4_ks_r4; +static_assert(sizeof(block_iq4_ks_r4) == 4*sizeof(block_iq4_ks), "wrong iq4_ks_r4 block size/padding"); + typedef struct { uint32_t qs[QK_K/8]; } block_iq4_kss; @@ -447,6 +566,14 @@ typedef struct { } block_iq2_k; static_assert(sizeof(block_iq2_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/32 + QK_K/4, "wrong iq2_k block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t extra[8]; + uint8_t scales[QK_K/8]; + uint8_t qs[QK_K]; +} block_iq2_k_r4; +static_assert(sizeof(block_iq2_k_r4) == 4*sizeof(block_iq2_k), "wrong iq2_k_r4 block size/padding"); + typedef struct { uint16_t extra; uint8_t scales[QK_K/64]; @@ -482,6 +609,16 @@ typedef struct { } block_iq3_k; static_assert(sizeof(block_iq3_k) == sizeof(ggml_half) + 2*sizeof(uint16_t) + QK_K/32 + QK_K/4 + QK_K/8, "wrong iq3_k block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t extra[8]; + uint8_t scales_h[QK_K/32]; + uint8_t scales_l[QK_K/8]; + uint8_t qs[QK_K]; + uint8_t qh[QK_K/2]; +} block_iq3_k_r4; +static_assert(sizeof(block_iq3_k_r4) == 4*sizeof(block_iq3_k), "wrong iq3_k_r4 block size/padding"); + typedef struct { ggml_half d; uint16_t extra; @@ -491,6 +628,15 @@ typedef struct { } block_iq4_k; static_assert(sizeof(block_iq4_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 + 3*QK_K/64, "wrong iq4_k block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t extra[8]; + uint8_t scales_h[QK_K/16]; + uint8_t scales_l[QK_K/8]; + uint8_t qs[QK_K*2]; +} block_iq4_k_r4; +static_assert(sizeof(block_iq4_k_r4) == 4*sizeof(block_iq4_k), "wrong iq4_k_r4 block size/padding"); + typedef struct { ggml_half d; uint16_t extra; @@ -501,6 +647,16 @@ typedef struct { } block_iq5_k; static_assert(sizeof(block_iq5_k) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/2 
+ QK_K/8 + 3*QK_K/64, "wrong iq5_k block size/padding"); +typedef struct { + ggml_half d[4]; + uint8_t extra[8]; + uint8_t scales_h[QK_K/16]; + uint8_t scales_l[QK_K/8 ]; + uint8_t qs[QK_K*2]; + uint8_t qh[QK_K/2]; +} block_iq5_k_r4; +static_assert(sizeof(block_iq5_k_r4) == 4*sizeof(block_iq5_k), "wrong iq5_k_r4 block size/padding"); + typedef struct { ggml_half d; uint16_t extra; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e6295b9d..074b2441 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -934,13 +934,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) block_q8_0 * restrict y = vy; -#if GGML_USE_IQK_MULMAT - const int nb4 = 4*(nb/4); -#else - const int nb4 = -1; -#endif #if defined(__ARM_NEON) - block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; for (int i = 0; i < nb; i++) { int i4 = i/4, ir = i%4; float32x4_t srcv [8]; @@ -959,27 +953,16 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) const float d = amax / ((1 << 7) - 1); const float id = d ? 
1.0f/d : 0.0f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { const float32x4_t v = vmulq_n_f32(srcv[j], id); const int32x4_t vi = vcvtnq_s32_f32(v); - if (i < nb4) { - y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); } } #elif defined(__wasm_simd128__) @@ -1016,14 +999,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) } } #elif defined(__AVX2__) || defined(__AVX__) - block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; -#ifdef __AVX2__ - const bool pack = true; -#else - const bool pack = false; -#endif for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); __m256 v1 = _mm256_loadu_ps( x + 8 ); @@ -1045,11 +1021,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) // Quantize these floats const float d = maxScalar / 127.f; - if (pack && i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1084,11 +1056,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } + _mm256_storeu_si256((__m256i *)y[i].qs, i0); #else // Since we don't have in AVX some necessary functions, // we split the registers in half and call AVX2 analogs from SSE @@ -1287,15 +1255,8 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) block_q8_1 * restrict y = vy; -#if GGML_USE_IQK_MULMAT - const int nb4 = 4*(nb/4); -#else - const int nb4 = -1; -#endif #if defined(__ARM_NEON) - block_q8_1_x4 * restrict y4 = vy; for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; float32x4_t srcv [8]; float32x4_t asrcv[8]; float32x4_t amaxv[8]; @@ -1312,11 +1273,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const float d = amax / ((1 << 7) - 1); const float id = d ? 
1.0f/d : 0.0f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); int32x4_t accv = vdupq_n_s32(0); @@ -1324,26 +1281,15 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const float32x4_t v = vmulq_n_f32(srcv[j], id); const int32x4_t vi = vcvtnq_s32_f32(v); - if (i < nb4) { - y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); accv = vaddq_s32(accv, vi); } - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } else { - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); } #elif defined(__wasm_simd128__) for (int i = 0; i < nb; i++) { @@ -1389,14 +1335,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) wasm_i32x4_extract_lane(accv, 3))); } #elif defined(__AVX2__) || defined(__AVX__) - block_q8_1_x4 * restrict y4 = vy; -#ifdef __AVX2__ - const bool pack = true; -#else - const bool pack = false; -#endif for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); __m256 v1 = _mm256_loadu_ps( x + 8 ); @@ -1418,11 +1357,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) // Quantize these floats const float d = max_scalar / 127.f; - if (pack && i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d 
= GGML_FP32_TO_FP16(d); - } + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1446,11 +1381,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #if defined(__AVX2__) // Compute the sum of the quants and set y[i].s - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } else { - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -1464,11 +1395,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); i0 = _mm256_permutevar8x32_epi32( i0, perm ); - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } + _mm256_storeu_si256((__m256i *)y[i].qs, i0); #else // Since we don't have in AVX some necessary functions, // we split the registers in half and call AVX2 analogs from SSE @@ -15199,6 +15126,29 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_IQ6_K: break; case GGML_TYPE_IQ4_KS: break; case GGML_TYPE_IQ4_KSS: break; + case GGML_TYPE_IQ4_NL_R4: break; + case GGML_TYPE_IQ4_XS_R4: break; + case GGML_TYPE_IQ2_XXS_R4: break; + case GGML_TYPE_IQ2_XS_R4: break; + case GGML_TYPE_IQ3_XXS_R4: break; + case GGML_TYPE_IQ3_S_R4: break; + case GGML_TYPE_IQ2_S_R4: break; + case GGML_TYPE_Q4_0_R4: break; + case GGML_TYPE_Q5_0_R4: break; + case GGML_TYPE_Q6_0_R4: break; + case GGML_TYPE_Q8_0_R4: break; + case GGML_TYPE_Q2_K_R4: break; + case 
GGML_TYPE_Q3_K_R4: break; + case GGML_TYPE_Q4_K_R4: break; + case GGML_TYPE_Q5_K_R4: break; + case GGML_TYPE_Q6_K_R4: break; + case GGML_TYPE_IQ2_K_R4: break; + case GGML_TYPE_IQ3_K_R4: break; + case GGML_TYPE_IQ4_K_R4: break; + case GGML_TYPE_IQ5_K_R4: break; + case GGML_TYPE_IQ4_KS_R4: break; + case GGML_TYPE_Q8_K_R8: break; + case GGML_TYPE_BF16_R16: break; case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: { @@ -15215,6 +15165,8 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_I64: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: + case GGML_TYPE_I2_S: // nothing to validate break; default: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index a40a6d37..b6d69011 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -33,7 +33,6 @@ void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_REST void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_xxs_ref(const float * GGML_RESTRICT x, block_iq2_xxs * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_xs_ref (const float * GGML_RESTRICT x, block_iq2_xs * GGML_RESTRICT y, int64_t k); @@ -43,7 +42,6 @@ void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGM void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_bn_ref (const float * GGML_RESTRICT x, block_iq1_bn * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_bn_ref (const float * 
GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -59,7 +57,6 @@ void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -69,7 +66,6 @@ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq1_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -97,7 +93,6 @@ void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_iq1_bn (const block_iq1_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product 
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -123,7 +118,6 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq1_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -136,7 +130,6 @@ size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_iq1_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q3_K(const 
float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e3047e1b..38150f3f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -714,8 +714,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_0, .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, .vec_dot = ggml_vec_dot_q4_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -735,7 +739,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_1, .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, +#if GGML_USE_IQK_MULMAT + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else .vec_dot_type = GGML_TYPE_Q8_1, +#endif #if defined (__ARM_FEATURE_MATMUL_INT8) .nrows = 2, #else @@ -778,8 +786,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_0, .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, .vec_dot = ggml_vec_dot_q5_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -795,7 +807,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_1, .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, +#if GGML_USE_IQK_MULMAT + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else .vec_dot_type = GGML_TYPE_Q8_1, +#endif .nrows = 1, .row_meta_size = 0, }, @@ -808,8 +824,12 @@ static const ggml_type_traits_t 
type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q6_0, .from_float_ref = (ggml_from_float_t) quantize_row_q6_0_ref, .vec_dot = ggml_vec_dot_q6_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -826,8 +846,16 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + // Remember: we cannot add 128 to the Q8 quants and use iblock sum in Q8_1 to subtract as we do on Zen4 for pure AVX2 + // because there the result of the _mm256_maddubs_epi16() instruction may overflow the int16_t range + // (and it gets saturated if it does), leading to wrong results. + // TODO: expose HAVE_FANCY_SIMD from iqk_mul_mat.cpp and use #ifdef HAVE_FANCY_SIMD instead of the above. 
+ .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -849,6 +877,26 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q8_0_X4] = { + .type_name = "q8_0_x4", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .from_float = quantize_row_q8_0_x4, + .from_float_ref = quantize_row_q8_0_x4, + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q8_1_X4] = { + .type_name = "q8_1_x4", + .blck_size = QK8_1, + .type_size = sizeof(block_q8_1), + .is_quantized = true, + .from_float = quantize_row_q8_1_x4, + .from_float_ref = quantize_row_q8_1_x4, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -862,6 +910,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q2_K_R4] = { + .type_name = "q2_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_q2_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q2_k_r4, + .from_float = quantize_row_q2_k_r4, + .from_float_ref = (ggml_from_float_t) quantize_row_q2_k_r4_ref, + .vec_dot = vec_dot_q2_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q3_K] = { .type_name = "q3_K", .blck_size = QK_K, @@ -875,6 +936,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q3_K_R4] = { + .type_name = "q3_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_q3_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q3_k_r4, + .from_float = quantize_row_q3_k_r4, + .from_float_ref = (ggml_from_float_t) quantize_row_q3_k_r4_ref, + .vec_dot = vec_dot_q3_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q4_K] = { .type_name = "q4_K", .blck_size = QK_K, @@ 
-888,6 +962,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q4_K_R4] = { + .type_name = "q4_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_q4_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_k_r4, + .from_float = quantize_row_q4_k_r4, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_k_r4_ref, + .vec_dot = vec_dot_q4_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K32, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q5_K] = { .type_name = "q5_K", .blck_size = QK_K, @@ -901,6 +988,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q5_K_R4] = { + .type_name = "q5_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_q5_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_k_r4, + .from_float = quantize_row_q5_k_r4, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_k_r4_ref, + .vec_dot = vec_dot_q5_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K32, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q6_K] = { .type_name = "q6_K", .blck_size = QK_K, @@ -914,6 +1014,32 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_Q6_K_R4] = { + .type_name = "q6_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_q6_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q6_k_r4, + .from_float = quantize_row_q6_k_r4, + .from_float_ref = (ggml_from_float_t) quantize_row_q6_k_r4_ref, + .vec_dot = vec_dot_q6_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q8_K_R8] = { + .type_name = "q8_k_r8", + .blck_size = QK_K, + .type_size = sizeof(block_q8_k_r8)/8, + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_k_r8, + .from_float = quantize_row_q8_k_r8, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_k_r8_ref, + 
.vec_dot = vec_dot_q8_k_r8_q8_k, + .vec_dot_type = GGML_TYPE_Q8_KR8, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ2_XXS] = { .type_name = "iq2_xxs", .blck_size = QK_K, @@ -927,6 +1053,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ2_XXS_R4] = { + .type_name = "iq2_xxs_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xxs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs_r4, + .from_float = quantize_row_iq2_xxs_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_xxs_r4_ref, + .vec_dot = vec_dot_iq2_xxs_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ2_XS] = { .type_name = "iq2_xs", .blck_size = QK_K, @@ -940,6 +1079,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ2_XS_R4] = { + .type_name = "iq2_xs_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_xs_r4, + .from_float = quantize_row_iq2_xs_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_xs_r4_ref, + .vec_dot = vec_dot_iq2_xs_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ3_XXS] = { .type_name = "iq3_xxs", .blck_size = QK_K, @@ -953,6 +1105,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ3_XXS_R4] = { + .type_name = "iq3_xxs_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_xxs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs_r4, + .from_float = quantize_row_iq3_xxs_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_r4_ref, + .vec_dot = vec_dot_iq3_xxs_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ3_S] = { .type_name = "iq3_s", .blck_size = QK_K, @@ 
-966,6 +1131,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ3_S_R4] = { + .type_name = "iq3_s_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_s_r4, + .from_float = quantize_row_iq3_s_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_r4_ref, + .vec_dot = vec_dot_iq3_s_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ2_S] = { .type_name = "iq2_s", .blck_size = QK_K, @@ -979,6 +1157,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ2_S_R4] = { + .type_name = "iq2_s_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_s), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_s_r4, + .from_float = quantize_row_iq2_s_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_r4_ref, + .vec_dot = vec_dot_iq2_s_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", .blck_size = QK_K, @@ -1026,11 +1217,24 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_iq2_bn, .from_float = quantize_row_iq2_bn, .from_float_ref = (ggml_from_float_t)quantize_row_iq2_bn_ref, - .vec_dot = ggml_vec_dot_iq2_bn_q8_K64, + .vec_dot = vec_dot_iq2_bn_q8_K64, .vec_dot_type = GGML_TYPE_Q8_K64, .nrows = 1, .row_meta_size = 4, }, + [GGML_TYPE_IQ2_BN_R4] = { + .type_name = "iq2_bn_r4", + .blck_size = QK_IQ1BN, + .type_size = sizeof(block_iq2_bn), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_bn_r4, + .from_float = quantize_row_iq2_bn_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_bn_r4_ref, + .vec_dot = vec_dot_iq2_bn_r4_q8_K64, + .vec_dot_type = GGML_TYPE_Q8_K16, + .nrows = 1, + .row_meta_size = 4, + }, 
[GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", .blck_size = QK4_NL, @@ -1040,8 +1244,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_nl, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, -#if GGML_USE_IQK_MULMAT && defined __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_1, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif #else .vec_dot_type = GGML_TYPE_Q8_0, #endif @@ -1074,6 +1282,23 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 4, }, + [GGML_TYPE_IQ4_KS_R4] = { + .type_name = "iq4_ks_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_ks), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_ks_r4, + .from_float = quantize_row_iq4_ks_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_ks_r4_ref, + .vec_dot = vec_dot_iq4_ks_r4_q8_k, +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_K32, +#else + .vec_dot_type = GGML_TYPE_Q8_K, +#endif + .nrows = 1, + .row_meta_size = 4, + }, [GGML_TYPE_IQ4_KSS] = { .type_name = "iq4_kss", .blck_size = QK_K, @@ -1103,6 +1328,30 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K64, .row_meta_size = 0, }, + [GGML_TYPE_Q8_K16] = { + .type_name = "q8_K16", + .blck_size = 64, + .type_size = 64, + .is_quantized = true, + .from_float = quantize_row_q8_K16, + .row_meta_size = 20, + }, + [GGML_TYPE_Q8_K32] = { + .type_name = "q8_K32", + .blck_size = QK_K, + .type_size = sizeof(block_q8_K), + .is_quantized = true, + .from_float = quantize_row_q8_K32, + .row_meta_size = 0, + }, + [GGML_TYPE_Q8_KR8] = { + .type_name = "q8_KR8", + .blck_size = QK_K, + .type_size = sizeof(block_q8_K), + .is_quantized = true, + .from_float = quantize_row_q8_KR8, + .row_meta_size = 0, + }, [GGML_TYPE_BF16] = { .type_name = "bf16", .blck_size = 1, @@ 
-1116,6 +1365,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_BF16_R16] = { + .type_name = "bf16_r16", + .blck_size = 1, + .type_size = sizeof(ggml_bf16_t), + .is_quantized = false, + //.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, + //.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + //.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, + //.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, + .vec_dot_type = GGML_TYPE_BF16, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_Q4_0_4_4] = { .type_name = "q4_0_4x4", .blck_size = QK4_0, @@ -1180,6 +1442,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ2_K_R4] = { + .type_name = "iq2_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq2_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq2_k_r4, + .from_float = quantize_row_iq2_k_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_k_r4_ref, + .vec_dot = vec_dot_iq2_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ2_KS] = { .type_name = "iq2_ks", .blck_size = QK_K, @@ -1258,6 +1533,32 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ4_K_R4] = { + .type_name = "iq4_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_k_r4, + .from_float = quantize_row_iq4_k_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_k_r4_ref, + .vec_dot = vec_dot_iq4_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_IQ3_K_R4] = { + .type_name = "iq3_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq3_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq3_k_r4, + .from_float = quantize_row_iq3_k_r4, + 
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_k_r4_ref, + .vec_dot = vec_dot_iq3_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ5_K] = { .type_name = "iq5_k", .blck_size = QK_K, @@ -1271,6 +1572,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ5_K_R4] = { + .type_name = "iq5_k_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq5_k), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq5_k_r4, + .from_float = quantize_row_iq5_k_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq5_k_r4_ref, + .vec_dot = vec_dot_iq5_k_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + .row_meta_size = 0, + }, [GGML_TYPE_IQ6_K] = { .type_name = "iq6_k", .blck_size = QK_K, @@ -1284,6 +1598,137 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .nrows = 1, .row_meta_size = 0, }, + [GGML_TYPE_IQ4_NL_R4] = { + .type_name = "iq4_nl_r4", + .blck_size = QK4_NL, + .type_size = sizeof(block_iq4_nl), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_nl_r4, + .from_float = quantize_row_iq4_nl_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_r4_ref, + .vec_dot = vec_dot_iq4_nl_r4_q8_0, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_IQ4_XS_R4] = { + .type_name = "iq4_xs_r4", + .blck_size = QK_K, + .type_size = sizeof(block_iq4_xs), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_xs_r4, + .from_float = quantize_row_iq4_xs_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_r4_ref, + .vec_dot = vec_dot_iq4_xs_r4_q8_k, + .vec_dot_type = GGML_TYPE_Q8_K32, + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q4_0_R4] = { + .type_name = "q4_0_r4", + 
.blck_size = QK4_NL, + .type_size = sizeof(block_iq4_nl), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_0_r4, + .from_float = quantize_row_q4_0_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_q4_0_r4_ref, + .vec_dot = vec_dot_q4_0_r4_q8_0, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q8_0_R4] = { + .type_name = "q8_0_r4", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_0_r4, + .from_float = quantize_row_q8_0_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_q8_0_r4_ref, + .vec_dot = vec_dot_q8_0_r4_q8_0, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q5_0_R4] = { + .type_name = "q5_0_r4", + .blck_size = QK5_0, + .type_size = sizeof(block_q5_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_0_r4, + .from_float = quantize_row_q5_0_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_q5_0_r4_ref, + .vec_dot = vec_dot_q5_0_r4_q8_0, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_Q6_0_R4] = { + .type_name = "q6_0_r4", + .blck_size = QK6_0, + .type_size = sizeof(block_q6_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q6_0_r4, + .from_float = quantize_row_q6_0_r4, + .from_float_ref = (ggml_from_float_t)quantize_row_q6_0_r4_ref, + .vec_dot = vec_dot_q6_0_r4_q8_0, +#if GGML_USE_IQK_MULMAT +#if defined __AVX2__ + .vec_dot_type = 
GGML_TYPE_Q8_1_X4, +#else + .vec_dot_type = GGML_TYPE_Q8_0_X4, +#endif +#else + .vec_dot_type = GGML_TYPE_Q8_0, +#endif + .nrows = 1, + .row_meta_size = 0, + }, + [GGML_TYPE_I2_S] = { + .type_name = "i2_s", + .blck_size = 1, + .type_size = 1, + .is_quantized = true, + .to_float = dequantize_row_ms_i2s, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .row_meta_size = 0, + }, }; // For internal test use @@ -3809,6 +4254,10 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } + // hack for I2_S + if(tensor->type == GGML_TYPE_I2_S) { + nbytes = nbytes / 4 + 32; + } } else { nbytes = tensor->nb[1]; //tensor->ne[0]*tensor->nb[0]/blck_size; @@ -3923,6 +4372,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; + case GGML_FTYPE_MOSTLY_BF16_R16: wtype = GGML_TYPE_BF16_R16;break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; @@ -3930,32 +4380,55 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q6_0: wtype = GGML_TYPE_Q6_0; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; - case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; - case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; - case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; - case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_Q2_K_R4: wtype = GGML_TYPE_Q2_K_R4; break; + case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; + case 
GGML_FTYPE_MOSTLY_Q3_K_R4: wtype = GGML_TYPE_Q3_K_R4; break; + case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q4_K_R4: wtype = GGML_TYPE_Q4_K_R4; break; + case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q5_K_R4: wtype = GGML_TYPE_Q5_K_R4; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_Q6_K_R4: wtype = GGML_TYPE_Q6_K_R4; break; + case GGML_FTYPE_MOSTLY_Q8_K_R8: wtype = GGML_TYPE_Q8_K_R8; break; case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; + case GGML_FTYPE_MOSTLY_IQ2_XXS_R4: wtype = GGML_TYPE_IQ2_XXS_R4;break; case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; + case GGML_FTYPE_MOSTLY_IQ2_XS_R4: wtype = GGML_TYPE_IQ2_XS_R4;break; case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; + case GGML_FTYPE_MOSTLY_IQ3_XXS_R4: wtype = GGML_TYPE_IQ3_XXS_R4;break; case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; case GGML_FTYPE_MOSTLY_IQ1_BN: wtype = GGML_TYPE_IQ1_BN; break; case GGML_FTYPE_MOSTLY_IQ2_BN: wtype = GGML_TYPE_IQ2_BN; break; + case GGML_FTYPE_MOSTLY_IQ2_BN_R4: wtype = GGML_TYPE_IQ2_BN_R4;break; case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; + case GGML_FTYPE_MOSTLY_IQ4_NL_R4: wtype = GGML_TYPE_IQ4_NL_R4;break; + case GGML_FTYPE_MOSTLY_IQ4_XS_R4: wtype = GGML_TYPE_IQ4_XS_R4;break; + case GGML_FTYPE_MOSTLY_Q4_0_R4: wtype = GGML_TYPE_Q4_0_R4; break; + case GGML_FTYPE_MOSTLY_Q5_0_R4: wtype = GGML_TYPE_Q5_0_R4; break; + case GGML_FTYPE_MOSTLY_Q6_0_R4: wtype = GGML_TYPE_Q6_0_R4; break; + case GGML_FTYPE_MOSTLY_Q8_0_R4: wtype = GGML_TYPE_Q8_0_R4; break; case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ4_KS: wtype = GGML_TYPE_IQ4_KS; break; + case GGML_FTYPE_MOSTLY_IQ4_KS_R4: wtype = GGML_TYPE_IQ4_KS_R4;break; case GGML_FTYPE_MOSTLY_IQ4_KSS: wtype = GGML_TYPE_IQ4_KSS; break; 
case GGML_FTYPE_MOSTLY_IQ2_K: wtype = GGML_TYPE_IQ2_K; break; + case GGML_FTYPE_MOSTLY_IQ2_K_R4: wtype = GGML_TYPE_IQ2_K_R4; break; case GGML_FTYPE_MOSTLY_IQ2_KS: wtype = GGML_TYPE_IQ2_KS; break; case GGML_FTYPE_MOSTLY_IQ2_KT: wtype = GGML_TYPE_IQ2_KT; break; case GGML_FTYPE_MOSTLY_IQ3_KT: wtype = GGML_TYPE_IQ3_KT; break; case GGML_FTYPE_MOSTLY_IQ4_KT: wtype = GGML_TYPE_IQ4_KT; break; case GGML_FTYPE_MOSTLY_IQ3_K: wtype = GGML_TYPE_IQ3_K; break; case GGML_FTYPE_MOSTLY_IQ4_K: wtype = GGML_TYPE_IQ4_K; break; + case GGML_FTYPE_MOSTLY_IQ3_K_R4: wtype = GGML_TYPE_IQ3_K_R4; break; + case GGML_FTYPE_MOSTLY_IQ4_K_R4: wtype = GGML_TYPE_IQ4_K_R4; break; case GGML_FTYPE_MOSTLY_IQ5_K: wtype = GGML_TYPE_IQ5_K; break; + case GGML_FTYPE_MOSTLY_IQ5_K_R4: wtype = GGML_TYPE_IQ5_K_R4; break; case GGML_FTYPE_MOSTLY_IQ6_K: wtype = GGML_TYPE_IQ6_K; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; + case GGML_FTYPE_MOSTLY_IQ3_S_R4: wtype = GGML_TYPE_IQ3_S_R4; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; + case GGML_FTYPE_MOSTLY_IQ2_S_R4: wtype = GGML_TYPE_IQ2_S_R4; break; case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break; case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; @@ -10456,32 +10929,56 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case 
GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -10900,33 +11397,59 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case 
GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -11042,33 +11565,59 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -13467,6 +14016,14 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } +static inline uint32_t simple_gcd(uint32_t a, uint32_t b) { + // Greatest common divisor via Euclid's algorithm; gcd(x, 0) == gcd(0, x) == x, so a zero argument terminates instead of looping forever + while (b != 0) { + const uint32_t r = a % b; a = b; b = r; + } + return a; +} + static void
ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -13483,10 +14040,12 @@ static void ggml_compute_forward_mul_mat( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; - ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; int64_t const vec_dot_num_rows = type_traits[type].nrows; int64_t const matmul_num_cols = type_traits[type].ncols; +#if !GGML_USE_IQK_MULMAT + ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; +#endif ggml_gemv_t const gemv = type_traits[type].gemv; ggml_gemm_t const gemm = type_traits[type].gemm; @@ -13571,7 +14130,7 @@ UseGgmlGemm1:; char * wdata = (char *)params->wdata + params->wsize - params->qsize; if (strncmp(src1->name, wdata - GGML_MAX_NAME, GGML_MAX_NAME) == 0) { - goto AlreadyQunatized; + goto AlreadyQuantized; } wdata += GGML_MAX_NAME; @@ -13589,6 +14148,7 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { int64_t i11_processed = 0; +#if !GGML_USE_IQK_MULMAT if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), @@ -13597,6 +14157,7 @@ UseGgmlGemm1:; } i11_processed = ne11 - ne11 % 4; } +#endif for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), @@ -13619,7 +14180,7 @@ UseGgmlGemm1:; //atomic_store(¶ms->shared->current_chunk, nth); } -AlreadyQunatized:; +AlreadyQuantized:; } const void * wdata = (src1->type == vec_dot_type) ? 
src1->data @@ -13627,14 +14188,31 @@ AlreadyQunatized:; #if GGML_USE_IQK_MULMAT if (src1->type != vec_dot_type && dst->type == GGML_TYPE_F32) { + // When K*Q and V*softmax(K*Q) (so ne12*ne13 > 1), it is better (faster) to have fewer threads processing + // one matrix multiplication, but work on several heads at once. + // Hence, we find the GCD(n12*ne13, nth) and have nth/GCD(n12*ne13, nth) threads per head. + // Leaving the previous version commented out for now just in case. const size_t row_size = ggml_row_size(vec_dot_type, ne10); - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!iqk_mul_mat(ne01, ne11, ne00, - src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), - vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), - (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available2; + int ntg = simple_gcd(ne12*ne13, nth); + int counter = 0; + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + if (counter++ % ntg == ith%ntg) { + if (!iqk_mul_mat(ne01, ne11, ne00, + src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), + (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + ith/ntg, nth/ntg)) goto IQK_MulMat_Not_Available2; + } + } + } + //for (int64_t i13 = 0; i13 < ne13; i13++) + // for (int64_t i12 = 0; i12 < ne12; i12++) + // if (!iqk_mul_mat(ne01, ne11, ne00, + // src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type), + // vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size, ///ggml_type_size(vec_dot_type), + // (float *)((char 
*)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), + // ith, nth)) goto IQK_MulMat_Not_Available2; return; } IQK_MulMat_Not_Available2:; @@ -14231,32 +14809,56 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -14612,33 +15214,59 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: 
+ case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -14888,33 +15516,59 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case 
GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -15484,6 +16138,7 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_BF16_R16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15491,35 +16146,64 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q6_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_0_X4: + case GGML_TYPE_Q8_1_X4: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K_R4: case GGML_TYPE_Q6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: + case GGML_TYPE_Q8_KR8: case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_BN: case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_I2_S: + case GGML_TYPE_Q8_0_R4: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: case 
GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: case GGML_TYPE_IQ5_K: + case GGML_TYPE_IQ5_K_R4: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_Q8_K: case GGML_TYPE_Q8_K64: + case GGML_TYPE_Q8_K16: + case GGML_TYPE_Q8_K32: case GGML_TYPE_Q4_0_4_4: case GGML_TYPE_Q4_0_4_8: case GGML_TYPE_Q4_0_8_8: @@ -16850,32 +17534,40 @@ static void ggml_compute_forward_flash_attn_ext_f16( #if GGML_USE_IQK_MULMAT if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) { - int64_t work_per_slice = D*nek1*neq1; - int ntg = 1; - if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; - else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; - else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; - if ((neq2*neq3)%(nth/ntg) == 0) { - //if (ith == 0) printf("%s: D = %d, neq2 = %d, neq1 = %d, nek1 = %d, ntg = %d, neq1/ntg = %d\n", __func__, - // (int)D, (int)neq2, (int)neq1, (int)nek1, ntg, (int)(neq1/ntg)); - int counter = 0; - for (int64_t iq3 = 0; iq3 < neq3; iq3++) { - for (int64_t iq2 = 0; iq2 < neq2; iq2++) { - if (counter++ % (nth/ntg) == ith/ntg) { - int iq1 = (ith%ntg)*neq1/ntg; - if (!iqk_flash_attn_noalibi(k->type, v->type, - D, neq1/ntg, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), - (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), - (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), - (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), - (const void *)((const char *)mask->data + iq1*mask->nb[1]), - scale, softcap, - (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; - } + // I keep changing my mind what is the best strategy to split the threads when processing + // multiple heads. 
This is my current thinking, the commented out code below was the previous. + int ntg = nth/simple_gcd(neq2*neq3, nth); + int64_t neq1g = (neq1 + ntg - 1)/ntg; + //int64_t work_per_slice = D*nek1*neq1; + //int ntg = 1; + // + // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix + // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of + // the number of threads processing the (iq2, iq3) matrix. + // + //if (neq1 >= 8*nth) { + // if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8; + // else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4; + // else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2; + //} + int counter = 0; + for (int64_t iq3 = 0; iq3 < neq3; iq3++) { + for (int64_t iq2 = 0; iq2 < neq2; iq2++) { + if (counter++ % (nth/ntg) == ith/ntg) { + int iq1 = (ith%ntg)*neq1g; + int this_neq1 = MIN(neq1g, neq1-iq1); + if (!iqk_flash_attn_noalibi(k->type, v->type, + D, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float), + (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]), + (const void *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]), + (const void *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]), + (const void *)((const char *)mask->data + iq1*mask->nb[1]), + scale, softcap, + (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable; } } - return; } + return; IQK_Flash_Attn_NotAvailable:; } @@ -20393,7 +21085,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; if (node->src[1]->type != vec_dot_type) { - cur_q = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); + cur_q = ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + //cur_q = 
ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); } } break; case GGML_OP_MUL_MAT_ID: @@ -20403,7 +21096,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const struct ggml_tensor * src1 = node->src[1]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; if (src1->type != vec_dot_type) { - cur_q += ggml_row_size(vec_dot_type, ggml_nelements(src1)); + cur_q += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + //cur_q += ggml_row_size(vec_dot_type, ggml_nelements(src1)); } const int n_as = src0->ne[2]; cur_q += GGML_PAD(cur, sizeof(int64_t)); // align @@ -22254,12 +22948,17 @@ void ggml_quantize_init(enum ggml_type type) { ggml_critical_section_start(); switch (type) { + case GGML_TYPE_IQ2_XXS_R4: iq2xs_init_impl(GGML_TYPE_IQ2_XXS); break; + case GGML_TYPE_IQ2_XS_R4: iq2xs_init_impl(GGML_TYPE_IQ2_XS); break; + case GGML_TYPE_IQ2_S_R4: iq2xs_init_impl(GGML_TYPE_IQ2_S); break; case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: iq2xs_init_impl(type); break; + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break; + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ3_S: iq3xs_init_impl(512); break; default: // nothing break; @@ -22319,31 +23018,54 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q6_0: result = quantize_q6_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K_R4: result = quantize_q2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, 
n_per_row, imatrix); break; + case GGML_TYPE_Q3_K_R4: result = quantize_q3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K_R4: result = quantize_q4_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K_R4: result = quantize_q5_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K_R4: result = quantize_q6_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_K_R8: result = quantize_q8_k_r8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_XXS_R4:result = quantize_iq2_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_XS_R4:result = quantize_iq2_xs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_XXS_R4:result = quantize_iq3_xxs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_S: result = quantize_iq3_s (src + start, (char *) 
dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_S_R4:result = quantize_iq3_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_S: result = quantize_iq2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_S_R4:result = quantize_iq2_s_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_BN: result = quantize_iq1_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_BN: result = quantize_iq2_bn (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_BN_R4:result = quantize_iq2_bn_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_NL_R4: result = quantize_iq4_nl_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_XS_R4: result = quantize_iq4_xs_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_R4: result = quantize_q4_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0_R4: result = quantize_q5_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_0_R4: result = quantize_q6_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0_R4: 
result = quantize_q8_0_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KS: result = quantize_iq4_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_KS_R4:result = quantize_iq4_ks_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KSS: result = quantize_iq4_kss(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_K: result = quantize_iq2_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ2_K_R4:result = quantize_iq2_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KS: result = quantize_iq2_ks (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_KT: result = quantize_iq2_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_KT: result = quantize_iq3_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_KT: result = quantize_iq4_kt (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_K: result = quantize_iq3_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_K: result = quantize_iq4_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ3_K_R4:result = quantize_iq3_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ4_K_R4:result = quantize_iq4_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; 
case GGML_TYPE_IQ5_K: result = quantize_iq5_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_IQ5_K_R4:result = quantize_iq5_k_r4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ6_K: result = quantize_iq6_k (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -22360,6 +23082,11 @@ size_t ggml_quantize_chunk( ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n); result = n * elemsize; } break; + case GGML_TYPE_BF16_R16: + { + repack_f32_bf16_r16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row); + result = nrows * row_size; + } break; case GGML_TYPE_F32: { size_t elemsize = sizeof(float); diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d7682e54..7ddaee2a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17,12 +17,14 @@ #include #include +#include #if defined IQK_IMPLEMENT #include "ggml-impl.h" #include "ggml-quants.h" #include "iqk_mul_mat.h" +#include "iqk_quantize.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" @@ -46,8 +48,57 @@ // For fp16/fp32 matri multiplications tiling is used to improve // performance. 
+#define FA_TIMING 0 + #include #include +#if FA_TIMING +#include +#include +struct Perf { + using TimePoint = std::chrono::time_point; + std::array times = {}; + std::mutex mutex; + bool report; + static auto cur_time() { return std::chrono::high_resolution_clock::now(); } + inline void accum(int what, const TimePoint& t1) { + auto t2 = cur_time(); + auto dt = delta(t1, t2); + std::lock_guard lock(mutex); + times[what] += dt; + } + inline void accum_nolock(int what, const TimePoint& t1) { + auto t2 = cur_time(); + auto dt = delta(t1, t2); + times[what] += dt; + } + inline void add(const Perf& other) { + std::lock_guard lock(mutex); + for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i]; + } + Perf(bool r) : report(r) {} + ~Perf() { + if (report) { + double tot = 0; + for (auto& t : times) tot += t; + if (!tot) return; + printf("======================= Timing: %g ms in total\n", tot); + for (int i = 0; i < int(times.size()); ++i) { + if (times[i]) { + printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%'); + } + } + } + } + static Perf& instance() { + static Perf p(true); + return p; + } + static double delta(const TimePoint& t1, const TimePoint& t2) { + return 1e-6*std::chrono::duration_cast(t2-t1).count(); + } +}; +#endif #ifdef _MSC_VER #define IQK_NOINLINE __declspec(noinline) @@ -87,6 +138,24 @@ struct DataInfo { inline void store(int ix, int iy, float result) const { *(dst_row(iy) + ix) = result; } +#ifdef __AVX__ + inline void store(int ix, int iy, __m128 result) const { + _mm_storeu_ps(dst_row(iy) + ix, result); + } + inline void store(int ix, int iy, __m256 result) const { + _mm256_storeu_ps(dst_row(iy) + ix, result); + } +#endif +#ifdef __AVX512F__ + inline void store(int ix, int iy, __m512 result) const { + _mm512_storeu_ps(dst_row(iy) + ix, result); + } +#endif +#ifdef __ARM_NEON + inline void store(int ix, int iy, float32x4_t result) const { + vst1q_f32(dst_row(iy) + ix, result); + } +#endif inline float * dst_row(int iy) 
const { if (!row_mapping) return s + (cur_y + iy)*bs; int i12 = row_mapping[cur_y + iy].i2; @@ -100,21 +169,37 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf struct MulMat { std::array funcs = {}; + mul_mat_t func16 = nullptr; inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { #ifdef __aarch64__ constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) #else constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) #endif + if (func16 && nrc_y >= 16) { + int n_step = (nrc_y - info.cur_y)/16; + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); + this_info.cur_y += 16; + } + } + info.cur_y += 16 * n_step; + if (info.cur_y == nrc_y) return; + } int ny = funcs.size(); while (!funcs[ny-1] && ny > 0) --ny; - int n_step = (nrc_y - info.cur_y)/ny; + int n_left = nrc_y - info.cur_y; + int n_step = n_left/ny; if (n_step > 0) { - if (n_step*ny != nrc_y) { + if (n_step*ny != n_left) { ++n_step; - int ny1 = nrc_y/n_step; + int ny1 = n_left/n_step; int ny2 = ny1 + 1; - int my1 = n_step*ny2 - nrc_y; + int my1 = n_step*ny2 - n_left; int my2 = n_step - my1; for (int ix = 0; ix < nrc_x; ix += k_x_step) { auto this_info = info; @@ -129,7 +214,7 @@ struct MulMat { this_info.cur_y += ny2; } } - info.cur_y += nrc_y; + info.cur_y += n_left; } else { for (int ix = 0; ix < nrc_x; ix += k_x_step) { @@ -144,12 +229,41 @@ struct MulMat { info.cur_y += ny * n_step; } } - int n_left = nrc_y - info.cur_y; + n_left = nrc_y - info.cur_y; if (n_left > 0) { funcs[n_left-1](n, vx, bx, info, nrc_x); } } static bool prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny); + 
static inline int num_rows(ggml_type type) { + switch (type) { + case GGML_TYPE_Q2_K_R4: + case GGML_TYPE_Q3_K_R4: + case GGML_TYPE_Q4_K_R4: + case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q4_0_R4: + case GGML_TYPE_Q5_0_R4: + case GGML_TYPE_Q6_0_R4: + case GGML_TYPE_Q8_0_R4: + case GGML_TYPE_IQ4_NL_R4: + case GGML_TYPE_IQ4_XS_R4: + case GGML_TYPE_IQ2_K_R4: + case GGML_TYPE_IQ3_K_R4: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ2_XXS_R4: + case GGML_TYPE_IQ2_XS_R4: + case GGML_TYPE_IQ2_S_R4: + case GGML_TYPE_IQ3_XXS_R4: + case GGML_TYPE_IQ3_S_R4: + case GGML_TYPE_IQ2_BN_R4: return 4; + case GGML_TYPE_Q8_K_R8: return 8; + case GGML_TYPE_BF16_R16: return 16; + default: return 1; + } + } private: template static void set_functions(MulMat& m); }; @@ -170,13 +284,15 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); //if (ith == 0) printf("%s: ne00 = %d, row_size_qx = %d, strideA = %d\n", __func__, int(ne00), int(row_size_qx), int(strideA)); - auto nrc_x = (Nx + nth - 1)/nth; + auto num_rows = MulMat::num_rows(ggml_type(typeA)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; auto first_x = ith*nrc_x; - if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; - DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; + DataInfo info{C + first_x*num_rows, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; - mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); + mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x*num_rows, row_size_qx, info, nrc_x*num_rows, Ny); return true; } @@ -192,11 +308,15 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { return false; } - size_t row_size_qx = 
strideA; //*ggml_type_size(ggml_type(typeA)); - size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB)); - int nrc_x = (Nx + nth - 1)/nth; - int first_x = ith*nrc_x; - if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + size_t row_size_qx = strideA; + size_t row_size_qy = strideB; + auto num_rows = MulMat::num_rows(ggml_type(typeA)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; + first_x *= num_rows; + nrc_x *= num_rows; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); @@ -308,6 +428,30 @@ template struct Q8 { const block_q8 * y[nrc_y]; }; +template struct Q8_16 { + + constexpr static int nrc_y = nrc; + + Q8_16(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto ptr = (const float *)info.src1_row(iy); + std::memcpy(d + 5*iy, ptr, 5*sizeof(float)); + y[iy] = (const int8_t *)(ptr + 5); + } + } + +#ifdef HAVE_FANCY_SIMD + inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); } +#endif + inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); } + inline float scale(int iy, int k) const { return d[5*iy+k]; } + inline float sum_row(int iy) const { return d[5*iy + 4]; } + inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); } + + float d[5*nrc_y]; + const int8_t * y[nrc_y]; +}; + struct Scales8KBase { template inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { @@ -2068,6 +2212,3333 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf #endif // Zen4 or vanilla AVX2 +template +static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, 
int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + Q8_16 q8(info); + auto m3 = _mm256_set1_epi8(0x3); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK_IQ1BN; + __m256i qx[4]; + if constexpr (nrc_y > 4) { + __m256i acc[nrc_y] = {}; + __m128 sum4[nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+0); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); + auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4); + sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4); + acc[iy] = _mm256_setzero_si256(); + } + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = 
_mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+1); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); + auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]); + s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4); + info.store(ix, iy, s4); + acc[iy] = _mm256_setzero_si256(); + } + } + } else { + __m256i acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+0); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], 
_mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+1); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]); + auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]); + auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); + sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); + sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4); + sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4); + info.store(ix, iy, sum4); + acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256(); + } + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + if constexpr (nrc_y == 1) { + mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x); + } else { + Q8_16 q8(info); + auto m3 = _mm512_set1_epi8(0x3); + 
int nb = n / QK_IQ1BN; + __m512i acc[2*nrc_y] = {}; + __m512i qx[8]; + for (int ix = 0; ix < nrc_x/8; ++ix) { + const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx); + const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx); + auto dl = _mm_loadu_ps(dptr1); + auto dh = _mm_loadu_ps(dptr2); + const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4); + const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib); + auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib); + qx[0] = _mm512_and_si512(bits_l, m3); + qx[1] = _mm512_and_si512(bits_h, m3); + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3); + qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3); + qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3); + qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3); + qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants64(iy, ib); + auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy); + sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy); + sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy); + sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + __m128 sum4; + for (int k = 0; k < 2; ++k) { + const auto& dx = k == 0 ? 
dl : dh; + auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]); + sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00))); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4); + sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4); + info.store(8*ix + 4*k, iy, sum4); + } + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } + } + if (int ix = 8*(nrc_x/8); ix < nrc_x) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib); + qx[0] = _mm512_and_si512(bits_l, m3); + qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3); + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants64(iy, ib); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf = _mm512_cvtepi32_ps(acc[iy]); + auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); + sum4 = 
_mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4); + sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4); + info.store(ix, iy, sum4); + } + } + } +} +#else +template +static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + mul_mat_iq2_bn_r4_q8_k16_avx2(n, vx, bx, info, nrc_x); +} +#endif + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + auto values = load_iq4nl_values_512(); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); + const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-64.f)); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), + 
_mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); + qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); + qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); + qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); + qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } +} +#else +template +static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m1 = _mm256_set1_epi16(1); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, 
values128); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + __m256 acc[nrc_y] = {}; + //__m256 acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); + auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + auto s1 = _mm256_sign_epi8(q1, q1); + auto s2 = _mm256_sign_epi8(q2, q2); + auto s3 = _mm256_sign_epi8(q3, q3); + auto s4 = _mm256_sign_epi8(q4, q4); + + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); + auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + 
info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} +#endif + +template +static void mul_mat_q4_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + __m256 acc[nrc_y] = {}; + float d8[8*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + _mm256_storeu_ps(d8 + 8*iy, scales); + } + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-4.f)); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); + auto q1 = _mm256_and_si256(bits1, m4); + auto q2 = _mm256_and_si256(bits2, m4); + auto q3 = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4); + auto q4 = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+4+k]), 
acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1) { + mul_mat_q4_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + return; + } + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); + const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[4*ib4+k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[4*ib4+k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-4.f)); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+0)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[4*ib4+k].qs+1)), + _mm256_loadu_si256((const __m256i *)iq4h[4*ib4+k].qs+1), 1); + qx[0] = _mm512_and_si512(bits1, m4); + qx[1] = _mm512_and_si512(bits2, m4); + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const 
__m256i*)q8.y[iy][ib4].qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } +} +#else +template +static void mul_mat_q4_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q4_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +} +#endif + +template +static void mul_mat_q5_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); /* AVX2 kernel for 4-row-interleaved Q5_0 blocks: the row loop below advances ix by 4 and issues one 4-float store per step, so a multiple of 4 rows is all it requires (same precondition as mul_mat_q4_0_r4_q8_1_avx2) */ + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m5 = _mm256_set1_epi8(0x10); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK5_0; + GGML_ASSERT(nb%4 == 0); + __m256 acc[nrc_y] = {}; + float d8[8*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const
__m128i *)q8.y[iy][ib4].d)); + _mm256_storeu_ps(d8 + 8*iy, scales); + } + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f)); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq5[4*ib4+k].qs+1); + auto hbits = _mm_loadu_si128((const __m128i *)iq5[4*ib4+k].qh); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits); + auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5)); + auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5)); + auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5)); + auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q5_0_r4_q8_1(int n,
const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1) { + mul_mat_q5_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + } else { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + auto m5 = _mm512_set1_epi8(0x10); + int nb = n / QK5_0; + GGML_ASSERT(nb%4 == 0); + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx); + const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l[4*ib4+k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h[4*ib4+k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-8.f)); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+0)), + _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[4*ib4+k].qs+1)), + _mm256_loadu_si256((const __m256i *)iq5h[4*ib4+k].qs+1), 1); + auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l[4*ib4+k].qh); + auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h[4*ib4+k].qh); + auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1); + auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1); + qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5)); + qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5)); + //qx[2] = 
_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5); + qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5)); + qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } + } +} +#else +template +static void mul_mat_q5_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q5_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +} +#endif + +template +static void mul_mat_q6_0_r4_q8_1_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m6 = 
_mm256_set1_epi8(0x30); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nb = n / QK6_0; + GGML_ASSERT(nb%4 == 0); + __m256 acc[nrc_y] = {}; + float d8[8*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)); + _mm256_storeu_ps(d8 + 8*iy, scales); + } + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-16.f)); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qs+1); + auto hbits = _mm256_loadu_si256((const __m256i *)iq6[4*ib4+k].qh); + auto q1 = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6)); + auto q2 = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6)); + auto q3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6)); + auto q4 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), q1, _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, q2, _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, q3, _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, q4, _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(q2, 
_mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); +#endif + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1) { + mul_mat_q6_0_r4_q8_1_avx2<1>(n, vx, bx, info, nrc_x); + } else { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + auto m6 = _mm512_set1_epi8(0x30); + int nb = n / QK6_0; + GGML_ASSERT(nb%4 == 0); + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx); + const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l[4*ib4+k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h[4*ib4+k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-16.f)); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+0)), + 
_mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qs+1)), + _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qs+1), 1); + auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l[4*ib4+k].qh); + auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h[4*ib4+k].qh); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1); + qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6); + qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);; + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), 
_mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } + } +} +#else +template +static void mul_mat_q6_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q6_0_r4_q8_1_avx2(n, vx, bx, info, nrc_x); +} +#endif + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + int nb = n / QK8_0; + GGML_ASSERT(nb%4 == 0); + if constexpr (nrc_y == 1) { + auto m127 = _mm256_set1_epi8(127); + auto m1 = _mm256_set1_epi16(1); + __m256 acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-63.5f)); + auto q1 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0), m127); + auto q2 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1), m127); + auto q3 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2), m127); + auto q4 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3), m127); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)))); + auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)))); + auto d4d8 = _mm256_mul_ps(scales, 
_mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } + } else { + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + auto m127 = _mm512_set1_epi8(127); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx); + const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[4*ib4+k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[4*ib4+k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(scales, _mm512_set1_ps(-63.5f)); + qx[0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+0)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+0), 1); + qx[1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+1)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+1), 1); + qx[2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+2)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+2), 1); + qx[3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+3)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+3), 1); + qx[0] = _mm512_add_epi8(qx[0], m127); + qx[1] = 
_mm512_add_epi8(qx[1], m127); + qx[2] = _mm512_add_epi8(qx[2], m127); + qx[3] = _mm512_add_epi8(qx[3], m127); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + auto dy = _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_add_ps(acc[2*iy+0], acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } + } +} +#else +template +static void mul_mat_q8_0_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK8_0; + GGML_ASSERT(nb%4 == 0); + __m256 acc[nrc_y] = {}; + float d8[4*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)); + _mm_storeu_ps(d8 + 4*iy, scales); + } 
+ for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq8[4*ib4+k].d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto q1 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+0); + auto q2 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+1); + auto q3 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+2); + auto q4 = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+3); + auto s1 = _mm256_sign_epi8(q1, q1); + auto s2 = _mm256_sign_epi8(q2, q2); + auto s3 = _mm256_sign_epi8(q3, q3); + auto s4 = _mm256_sign_epi8(q4, q4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); + auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); + auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} +#endif + +template +static void mul_mat_iq4_xs_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = 
MM256_SET_M128I(values128, values128); +#else + auto values = load_iq4nl_values_256(); +#endif + int nbl = n / QK_K; + using helper_t = union { __m256i vec; uint32_t val[8]; }; + helper_t h; + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto slbits = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_l); + auto sl = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(slbits, 4), slbits), _mm256_set1_epi8(0xf)); + auto aux64 = (const uint64_t *)iq4[ibl].scales_h; + auto shbits = _mm_set_epi64x(aux64[0] >> 2, aux64[0]); + auto sh = _mm256_and_si256(MM256_SET_M128I(shbits, _mm_slli_epi16(shbits, 4)), _mm256_set1_epi8(0x30)); + h.vec = _mm256_sub_epi8(_mm256_or_si256(sl, sh), _mm256_set1_epi8(32)); + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f)); + for (int iy = 0; iy < nrc_y; ++iy) { + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); + } +#else + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); +#endif + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + qx[3] = _mm256_shuffle_epi8(values, 
_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); +#ifndef HAVE_FANCY_SIMD + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const 
DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1){ + mul_mat_iq4_xs_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); + } else { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + auto values = load_iq4nl_values_512(); + int nbl = n / QK_K; + using helper_t = union { __m512i vec; uint32_t val[16]; }; + helper_t h; + __m512 acc[nrc_y] = {}; + __m512i isum[nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_xs_r4 * iq4l = (const block_iq4_xs_r4 *)((const char *)vx + (ix+0)*bx); + const block_iq4_xs_r4 * iq4h = (const block_iq4_xs_r4 *)((const char *)vx + (ix+4)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d)); + auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d)); + auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1); + auto d4x64 = _mm512_mul_ps(d4, _mm512_set1_ps(-64.f)); + auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l); + auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l); + auto sl_l = MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l); + auto sh_l = MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l); + auto slb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_l), sh_l, 1), m4); + auto aux64 = (const uint64_t *)iq4l[ibl].scales_h; + auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]); + aux64 = (const uint64_t *)iq4h[ibl].scales_h; + auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]); + auto sl_h = MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4)); + auto sh_h = MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4)); + auto shb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_h), sh_h, 1), _mm512_set1_epi8(0x30)); + h.vec = _mm512_sub_epi8(_mm512_or_si512(slb, shb), _mm512_set1_epi8(32)); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto iscales = 
_mm512_cvtepi8_epi32(_mm_blend_epi32(_mm_set1_epi32(h.val[ib+0]), _mm_set1_epi32(h.val[ib+8]), 0x0c)); + auto scales = _mm512_cvtepi32_ps(iscales); + auto scales_m = _mm512_mul_ps(scales, d4x64); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)), + _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)), + _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1); + qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); + qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); + qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); + qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi)); + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm512_setzero_si512(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1)); + auto 
sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + acc[iy] = _mm512_setzero_ps(); + } + } + } +} +#else +template +static void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_iq4_xs_r4_q8_k_avx2(n, vx, bx, info, nrc_x); +} +#endif + +template +static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); +#else + auto values = load_iq4nl_values_256(); +#endif + int nbl = n / QK_K; + using helper_t = union { __m256i vec; uint32_t val[8]; }; +#ifndef HAVE_FANCY_SIMD + helper_t h, h_shift; +#else + using helper512_t = union { __m512i vec; uint64_t val[8]; }; + helper_t h; + helper512_t h_shift; +#endif + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + (ix+0)*bx); + const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4); + auto d4 = _mm_loadu_ps(dptr); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales); + h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127)); +#ifndef HAVE_FANCY_SIMD + h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2); + { + __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))), + 
_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0]))))); + __m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[1])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[1]))))); + __m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[2])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2]))))); + __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))), + _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3]))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v2, _mm256_shuffle_ps(m8, m8, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v3, _mm256_shuffle_ps(m8, m8, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]); + } + } +#else + auto shift = _mm256_add_epi8(_mm256_set1_epi8(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1)); + h_shift.vec = _mm512_mullo_epi16(_mm512_cvtepi8_epi16(shift), _mm512_cvtepi8_epi16(h.vec)); +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto ishifts = _mm256_cvtepi16_epi32(_mm_set1_epi64x(h_shift.val[ib])); + auto scales_m = _mm256_cvtepi32_ps(ishifts); + for (int iy = 0; iy < nrc_y; ++iy) { + float m8 = ((const float 
*)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); + } +#endif + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); + qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); +#ifndef HAVE_FANCY_SIMD + auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], 
_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, ibl)), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, _mm_mul_ps(d4, sum)); + } + } +} + +template +static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; +#ifndef HAVE_FANCY_SIMD + auto smask = _mm256_set1_epi64x(0x8040201008040201); + auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + auto m4 = _mm256_set1_epi8(4); + auto m1 = _mm256_set1_epi16(1); +#endif + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto qs = iq2[ibl].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + qx[0] = _mm256_set_epi64x(iq2xxs_grid[qs[ 3]], iq2xxs_grid[qs[ 2]], iq2xxs_grid[qs[ 1]], iq2xxs_grid[qs[ 0]]); + qx[1] = _mm256_set_epi64x(iq2xxs_grid[qs[ 7]], iq2xxs_grid[qs[ 6]], iq2xxs_grid[qs[ 5]], iq2xxs_grid[qs[ 4]]); + qx[2] = _mm256_set_epi64x(iq2xxs_grid[qs[11]], iq2xxs_grid[qs[10]], iq2xxs_grid[qs[ 9]], iq2xxs_grid[qs[ 8]]); + qx[3] = _mm256_set_epi64x(iq2xxs_grid[qs[15]], iq2xxs_grid[qs[14]], iq2xxs_grid[qs[13]], iq2xxs_grid[qs[12]]); + qs += 16; + auto sas = _mm_loadu_si128((const __m128i *)iq2[ibl].sas + ib); + auto scales = _mm_and_si128(sas, _mm_set1_epi8(1)); +#ifdef HAVE_FANCY_SIMD + 
scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402)); +#else + scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402)); + scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1)); +#endif + auto scales32 = MM256_SET_M128I(scales, scales); + auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed. Needed to shutup compiler warning. + signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1)); +#ifdef HAVE_FANCY_SIMD + auto mask = (const __mmask32 *)&signs128; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi)); + } +#else + auto signs = MM256_SET_M128I(signs128, signs128); + auto shuffle = sign_shuffle; + auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), 
smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); + auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); + auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); + auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi)); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +template +static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; +#ifndef HAVE_FANCY_SIMD + auto 
smask = _mm256_set1_epi64x(0x8040201008040201); + auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + auto m4 = _mm256_set1_epi8(4); +#endif + __m256 acc[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + __m256i shuffles[2] = { + _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100), + _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) + }; + __m256i isum[2*nrc_y] = {}; +#else + __m256i shuffles[4] = { + MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)), + MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)), + MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)), + MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)), + }; + __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {}; +#endif + auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200); + __m256i qx[4]; + union { __m256i vec; uint16_t val[16]; } helper; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto s32 = (const uint32_t *)iq2[ibl].scales; + for (int ib = 0; ib < QK_K/32; ++ib) { + auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib); + helper.vec = _mm256_and_si256(val, _mm256_set1_epi16(511)); + qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]); + qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]); + qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]); + qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], 
iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]); + auto signs16 = _mm256_srli_epi16(val, 9); + signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1)); + auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8)); + signs128 = _mm_shuffle_epi8(signs128, s_shuffle); + auto scales = _mm_set1_epi32(s32[ib]); + scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf)); + scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1)); + auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7 +#ifdef HAVE_FANCY_SIMD + __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) }; + auto mask = (const __mmask32 *)&signs128; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0 + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); // blocks: 2,2,2,2, 3,3,3,3, row 1 + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); // blocks: 4,4,4,4, 5,5,5,5, row 2 + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); // blocks: 6,6,6,6, 7,7,7,7, row 3 + auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3 + auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7 + isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12)); + isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34)); + } +#else + auto signs = MM256_SET_M128I(signs128, signs128); + auto shuffle 
= sign_shuffle; + auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + __m256i scs[4] = { + _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]), + _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]), + }; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + if constexpr (nrc_y == 1) { + isum[0] = _mm256_add_epi32(isum[0], _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)))); + isum[1] = _mm256_add_epi32(isum[1], _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)))); + isum[2] = _mm256_add_epi32(isum[2], _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)))); + isum[3] = _mm256_add_epi32(isum[3], _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)))); + } else { + auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0 + auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1 + auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2 + auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // 
blocks 4x6, 4x7, row 3 + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], sumi); + } + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256(); +#else + if constexpr (nrc_y == 1) { + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1])); + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3])); + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256(); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +static void mul_mat_iq2_xs_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + constexpr int nrc_y = 16; + Q8 q8(info); + int nbl = n / QK_K; +#ifndef HAVE_FANCY_SIMD + auto smask = 
_mm256_set1_epi64x(0x8040201008040201); + auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + auto m4 = _mm256_set1_epi8(4); +#endif + __m256 acc[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + __m256i shuffles[2] = { + _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100), + _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) + }; + __m256i isum[2*nrc_y] = {}; +#else + __m256i shuffles[4] = { + MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)), + MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)), + MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)), + MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)), + }; + __m256i isum[nrc_y == 1 ? 4 : nrc_y] = {}; +#endif + auto s_shuffle = _mm_set_epi64x(0x0f0d0b0907050301, 0x0e0c0a0806040200); + __m256i qx[4]; + union { __m256i vec; uint16_t val[16]; } helper; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto s32 = (const uint32_t *)iq2[ibl].scales; + { + auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales); + auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf)); + auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf)); + scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1)); + scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1)); + auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row) + auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16..31, 48...63 (4...7, 12..15 from each row) + auto s1_16 = 
_mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row) + auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8..11 from each row) + auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row) + auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12.15 from each row) + auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row + auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row + auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row + auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + for (int ib = 0; ib < QK_K/32; ++ib) { + auto val = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs + ib); + helper.vec = _mm256_and_si256(val, 
_mm256_set1_epi16(511)); + qx[0] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 3]], iq2xs_grid[helper.val[ 2]], iq2xs_grid[helper.val[ 1]], iq2xs_grid[helper.val[ 0]]); + qx[1] = _mm256_set_epi64x(iq2xs_grid[helper.val[ 7]], iq2xs_grid[helper.val[ 6]], iq2xs_grid[helper.val[ 5]], iq2xs_grid[helper.val[ 4]]); + qx[2] = _mm256_set_epi64x(iq2xs_grid[helper.val[11]], iq2xs_grid[helper.val[10]], iq2xs_grid[helper.val[ 9]], iq2xs_grid[helper.val[ 8]]); + qx[3] = _mm256_set_epi64x(iq2xs_grid[helper.val[15]], iq2xs_grid[helper.val[14]], iq2xs_grid[helper.val[13]], iq2xs_grid[helper.val[12]]); + auto signs16 = _mm256_srli_epi16(val, 9); + signs16 = _mm256_xor_si256(signs16, _mm256_slli_epi16(signs16, 1)); + auto signs128 = _mm_or_si128(_mm256_castsi256_si128(signs16), _mm_slli_epi16(_mm256_extracti128_si256(signs16, 1), 8)); + signs128 = _mm_shuffle_epi8(signs128, s_shuffle); + auto scales = _mm_set1_epi32(s32[ib]); + scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf)); + scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1)); + auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7 +#ifdef HAVE_FANCY_SIMD + __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) }; + auto mask = (const __mmask32 *)&signs128; + qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0])); + qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1])); + qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2])); + qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 
0,0,0,0, 1,1,1,1, row 0 + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1 + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2 + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3 + auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3 + auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7 + isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12)); + isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34)); + } +#else + auto signs = MM256_SET_M128I(signs128, signs128); + auto shuffle = sign_shuffle; + auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s)); + s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s)); + s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s)); + s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s)); + __m256i scs[4] = { + _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]), + _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]), + }; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = 
_mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0 + auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1 + auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2 + auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3 + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], sumi); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256(); +#else + if constexpr (nrc_y == 1) { + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1])); + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3])); + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256(); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = 
_mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +template +static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; +#ifndef HAVE_FANCY_SIMD + auto smask = _mm256_set1_epi64x(0x8040201008040201); + auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + auto m4 = _mm256_set1_epi8(4); +#endif + __m256 acc[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + __m256i shuffles[2] = { + _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100), + _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) + }; + __m256i isum[2*nrc_y] = {}; +#else + __m256i shuffles[4] = { + MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)), + MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)), + MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)), + MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)), + }; + __m256i isum[nrc_y == 1 ? 
4 : nrc_y] = {}; +#endif + __m256i qx[4]; + auto grid = iq2s_grid; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto s32 = (const uint32_t *)iq2[ibl].scales; + auto ql = iq2[ibl].qs; + auto qh = iq2[ibl].qh; + for (int ib = 0; ib < QK_K/32; ++ib) { + qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]); + qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]); + qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]); + qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]); + ql += 16; qh += 4; + auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib); + auto scales = _mm_set1_epi32(s32[ib]); + scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf)); + scales = _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi8(1)); + auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7 +#ifdef HAVE_FANCY_SIMD + __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) }; + auto mask = (const __mmask32 *)&signs128; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], 
_mm256_setzero_si256(), y)); // blocks: 0,0,0,0, 1,1,1,1, row 0 + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); // blocks: 2,2,2,2, 3,3,3,3, row 1 + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); // blocks: 4,4,4,4, 5,5,5,5, row 2 + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); // blocks: 6,6,6,6, 7,7,7,7, row 3 + auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3 + auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7 + isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12)); + isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34)); + } +#else + auto signs = MM256_SET_M128I(signs128, signs128); + auto shuffle = sign_shuffle; + auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + __m256i scs[4] = { + _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]), + _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]), + }; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + if constexpr (nrc_y == 
1) { + isum[0] = _mm256_add_epi32(isum[0], _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)))); + isum[1] = _mm256_add_epi32(isum[1], _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)))); + isum[2] = _mm256_add_epi32(isum[2], _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)))); + isum[3] = _mm256_add_epi32(isum[3], _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)))); + } else { + auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); // blocks 4x0, 4x1, row 0 + auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); // blocks 4x2, 4x3, row 1 + auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); // blocks 4x4, 4x5, row 2 + auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); // blocks 4x6, 4x7, row 3 + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], sumi); + } + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256(); +#else + if constexpr (nrc_y == 1) { + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[0], isum[1]), _mm256_unpackhi_epi32(isum[0], isum[1])); + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(isum[2], isum[3]), _mm256_unpackhi_epi32(isum[2], isum[3])); + auto sumi = 
_mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[0] = isum[1] = isum[2] = isum[3] = _mm256_setzero_si256(); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +static void mul_mat_iq2_s_r4_q8_k_16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + constexpr int nrc_y = 16; + Q8 q8(info); + int nbl = n / QK_K; +#ifndef HAVE_FANCY_SIMD + auto smask = _mm256_set1_epi64x(0x8040201008040201); + auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + auto m4 = _mm256_set1_epi8(4); +#endif + __m256 acc[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + __m256i shuffles[2] = { + _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100), + _mm256_set_epi64x(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908) + }; + __m256i isum[2*nrc_y] = {}; +#else + __m256i shuffles[4] = { + MM256_SET_M128I(_mm_set1_epi16(0x0302), _mm_set1_epi16(0x0100)), + MM256_SET_M128I(_mm_set1_epi16(0x0706), _mm_set1_epi16(0x0504)), + MM256_SET_M128I(_mm_set1_epi16(0x0b0a), _mm_set1_epi16(0x0908)), + MM256_SET_M128I(_mm_set1_epi16(0x0f0e), _mm_set1_epi16(0x0d0c)), + }; + __m256i isum[nrc_y == 1 ? 
4 : nrc_y] = {}; +#endif + __m256i qx[4]; + auto grid = iq2s_grid; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto s32 = (const uint32_t *)iq2[ibl].scales; + auto ql = iq2[ibl].qs; + auto qh = iq2[ibl].qh; + { + auto scale_bits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales); + auto scales1 = _mm256_and_si256(scale_bits, _mm256_set1_epi8(0xf)); + auto scales2 = _mm256_and_si256(_mm256_srli_epi16(scale_bits, 4), _mm256_set1_epi8(0xf)); + scales1 = _mm256_or_si256(_mm256_slli_epi16(scales1, 1), _mm256_set1_epi8(1)); + scales2 = _mm256_or_si256(_mm256_slli_epi16(scales2, 1), _mm256_set1_epi8(1)); + auto s1_8 = _mm256_unpacklo_epi8(scales1, scales2); // blocks 0...15, 32...47 (0...3, 8...11 from each row) + auto s2_8 = _mm256_unpackhi_epi8(scales1, scales2); // blocks 16..31, 48...63 (4...7, 12..15 from each row) + auto s1_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s1_8)); // 0...15 (0...3 from each row) + auto s2_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s1_8, 1)); // 32...47 (8..11 from each row) + auto s3_16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(s2_8)); // 16...31 (4...7 from each row) + auto s4_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(s2_8, 1)); // 48...63 (12.15 from each row) + auto t1 = MM256_SET_M128I(_mm256_castsi256_si128(s2_16), _mm256_castsi256_si128(s1_16)); // 0,1 and 8,9 from each row + auto t2 = MM256_SET_M128I(_mm256_extracti128_si256(s2_16, 1), _mm256_extracti128_si256(s1_16, 1)); // 2,3 and 10,11 from each row + auto t3 = MM256_SET_M128I(_mm256_castsi256_si128(s4_16), _mm256_castsi256_si128(s3_16)); // 4,5 and 12,13 from each row + auto t4 = MM256_SET_M128I(_mm256_extracti128_si256(s4_16, 1), _mm256_extracti128_si256(s3_16, 1)); // 6,7 and 14,15 from each row + for (int iy = 0; iy 
< nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, t1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, t2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, t3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, t4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(t4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(-64.f*q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + for (int ib = 0; ib < QK_K/32; ++ib) { + qx[0] = _mm256_set_epi64x(grid[ql[ 3] | ((qh[0] << 2) & 0x300)], grid[ql[ 2] | ((qh[0] << 4) & 0x300)], grid[ql[ 1] | ((qh[0] << 6) & 0x300)], grid[ql[ 0] | ((qh[0] << 8) & 0x300)]); + qx[1] = _mm256_set_epi64x(grid[ql[ 7] | ((qh[1] << 2) & 0x300)], grid[ql[ 6] | ((qh[1] << 4) & 0x300)], grid[ql[ 5] | ((qh[1] << 6) & 0x300)], grid[ql[ 4] | ((qh[1] << 8) & 0x300)]); + qx[2] = _mm256_set_epi64x(grid[ql[11] | ((qh[2] << 2) & 0x300)], grid[ql[10] | ((qh[2] << 4) & 0x300)], grid[ql[ 9] | ((qh[2] << 6) & 0x300)], grid[ql[ 8] | ((qh[2] << 8) & 0x300)]); + qx[3] = _mm256_set_epi64x(grid[ql[15] | ((qh[3] << 2) & 0x300)], grid[ql[14] | ((qh[3] << 4) & 0x300)], grid[ql[13] | ((qh[3] << 6) & 0x300)], grid[ql[12] | ((qh[3] << 8) & 0x300)]); + ql += 16; qh += 4; + auto signs128 = _mm_loadu_si128((const __m128i*)iq2[ibl].signs + ib); + auto scales = _mm_set1_epi32(s32[ib]); + scales = _mm_and_si128(_mm_unpacklo_epi8(scales, _mm_srli_epi16(scales, 4)), _mm_set1_epi8(0xf)); + scales = _mm_or_si128(_mm_slli_epi16(scales, 1), 
_mm_set1_epi8(1)); + auto scales16 = _mm256_cvtepi8_epi16(scales); // 0...7, 0...7 +#ifdef HAVE_FANCY_SIMD + __m256i scs[2] = { _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]) }; + auto mask = (const __mmask32 *)&signs128; + qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[0], mask[0], _mm256_setzero_si256(), qx[0])); + qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[1], mask[1], _mm256_setzero_si256(), qx[1])); + qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[2], mask[2], _mm256_setzero_si256(), qx[2])); + qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_mask_sub_epi8(qx[3], mask[3], _mm256_setzero_si256(), qx[3])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], y); // blocks: 0,0,0,0, 1,1,1,1, row 0 + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], y); // blocks: 2,2,2,2, 3,3,3,3, row 1 + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], y); // blocks: 4,4,4,4, 5,5,5,5, row 2 + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], y); // blocks: 6,6,6,6, 7,7,7,7, row 3 + auto s12 = _mm256_packs_epi32(sumi1, sumi2); // 0,0,0,0, 2,2,2,2, 1,1,1,1, 3,3,3,3 + auto s34 = _mm256_packs_epi32(sumi3, sumi4); // 4,4,4,4, 6,6,6,6, 5,5,5,5, 7,7,7,7 + isum[2*iy+0] = _mm256_add_epi32(isum[2*iy+0], _mm256_madd_epi16(scs[0], s12)); + isum[2*iy+1] = _mm256_add_epi32(isum[2*iy+1], _mm256_madd_epi16(scs[1], s34)); + } +#else + auto signs = MM256_SET_M128I(signs128, signs128); + auto shuffle = sign_shuffle; + auto s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + qx[0] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[0], s)); + s = 
_mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + qx[1] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[1], s)); + s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + qx[2] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[2], s)); + s = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + qx[3] = _mm256_add_epi8(_mm256_set1_epi8(64), _mm256_sign_epi8(qx[3], s)); + __m256i scs[4] = { + _mm256_shuffle_epi8(scales16, shuffles[0]), _mm256_shuffle_epi8(scales16, shuffles[1]), + _mm256_shuffle_epi8(scales16, shuffles[2]), _mm256_shuffle_epi8(scales16, shuffles[3]), + }; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_madd_epi16(scs[0], _mm256_maddubs_epi16(qx[0], y)); // blocks 4x0, 4x1, row 0 + auto sumi2 = _mm256_madd_epi16(scs[1], _mm256_maddubs_epi16(qx[1], y)); // blocks 4x2, 4x3, row 1 + auto sumi3 = _mm256_madd_epi16(scs[2], _mm256_maddubs_epi16(qx[2], y)); // blocks 4x4, 4x5, row 2 + auto sumi4 = _mm256_madd_epi16(scs[3], _mm256_maddubs_epi16(qx[3], y)); // blocks 4x6, 4x7, row 3 + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], sumi); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_hadd_epi32(isum[2*iy+0], isum[2*iy+1]); + acc[iy] = 
_mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + isum[2*iy+0] = isum[2*iy+1] = _mm256_setzero_si256(); +#else + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, _mm_mul_ps(_mm_set1_ps(0.125f), sum)); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +template +static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; +#ifndef HAVE_FANCY_SIMD + auto smask = _mm256_set1_epi64x(0x8040201008040201); + auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + auto m4 = _mm256_set1_epi8(4); + auto m1 = _mm256_set1_epi16(1); +#endif + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_mul_ps(_mm_set1_ps(0.25f), _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d))); // TODO: absorb the 0.25 factor into d when quantizing/repacking + auto d4 = _mm256_set_m128(dl, dl); + for (int ib = 0; ib < QK_K/32; ++ib) { + qx[0] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+ 7]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 6]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 5]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 4]], + iq3xxs_grid[iq3[ibl].qs[32*ib+ 3]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 2]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 1]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 0]]); + qx[1] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+15]], iq3xxs_grid[iq3[ibl].qs[32*ib+14]], iq3xxs_grid[iq3[ibl].qs[32*ib+13]], 
iq3xxs_grid[iq3[ibl].qs[32*ib+12]], + iq3xxs_grid[iq3[ibl].qs[32*ib+11]], iq3xxs_grid[iq3[ibl].qs[32*ib+10]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 9]], iq3xxs_grid[iq3[ibl].qs[32*ib+ 8]]); + qx[2] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+23]], iq3xxs_grid[iq3[ibl].qs[32*ib+22]], iq3xxs_grid[iq3[ibl].qs[32*ib+21]], iq3xxs_grid[iq3[ibl].qs[32*ib+20]], + iq3xxs_grid[iq3[ibl].qs[32*ib+19]], iq3xxs_grid[iq3[ibl].qs[32*ib+18]], iq3xxs_grid[iq3[ibl].qs[32*ib+17]], iq3xxs_grid[iq3[ibl].qs[32*ib+16]]); + qx[3] = _mm256_set_epi32(iq3xxs_grid[iq3[ibl].qs[32*ib+31]], iq3xxs_grid[iq3[ibl].qs[32*ib+30]], iq3xxs_grid[iq3[ibl].qs[32*ib+29]], iq3xxs_grid[iq3[ibl].qs[32*ib+28]], + iq3xxs_grid[iq3[ibl].qs[32*ib+27]], iq3xxs_grid[iq3[ibl].qs[32*ib+26]], iq3xxs_grid[iq3[ibl].qs[32*ib+25]], iq3xxs_grid[iq3[ibl].qs[32*ib+24]]); + auto sas = _mm_loadu_si128((const __m128i *)iq3[ibl].sas + ib); + auto scales = _mm_and_si128(sas, _mm_set1_epi8(1)); +#ifdef HAVE_FANCY_SIMD + scales = _mm_dpbusd_epi32(_mm_set1_epi32(1), scales, _mm_set1_epi32(0x10080402)); +#else + scales = _mm_maddubs_epi16(scales, _mm_set1_epi32(0x10080402)); + scales = _mm_add_epi32(_mm_madd_epi16(_mm_set1_epi16(1), scales), _mm_set1_epi32(1)); + //auto t1 = _mm_or_si128(_mm_and_si128(scales, _mm_set1_epi32(0x00000001)), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00000100)), 7)); + //auto t2 = _mm_or_si128(_mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x00010000)), 14), _mm_srli_epi32(_mm_and_si128(scales, _mm_set1_epi32(0x01000000)), 21)); + //scales = _mm_or_si128(_mm_slli_epi32(_mm_or_si128(t1, t2), 1), _mm_set1_epi32(1)); +#endif + auto scales32 = MM256_SET_M128I(scales, scales); + auto signs128 = _mm_and_si128(sas, _mm_set1_epi8(-2)); // 0xfe = -2 as signed. Needed to shutup compiler warning. 
+ signs128 = _mm_xor_si128(signs128, _mm_srli_epi16(signs128, 1)); +#ifdef HAVE_FANCY_SIMD + auto mask = (const __mmask32 *)&signs128; + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_mask_sub_epi8(y, mask[0], _mm256_setzero_si256(), y)); + auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[1], _mm256_mask_sub_epi8(y, mask[1], _mm256_setzero_si256(), y)); + auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2], _mm256_mask_sub_epi8(y, mask[2], _mm256_setzero_si256(), y)); + auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[3], _mm256_mask_sub_epi8(y, mask[3], _mm256_setzero_si256(), y)); + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi)); + } +#else + auto signs = MM256_SET_M128I(signs128, signs128); + auto shuffle = sign_shuffle; + auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + shuffle = _mm256_add_epi8(shuffle, m4); + auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(signs, shuffle), smask), smask), _mm256_set1_epi8(1)); + for (int iy = 0; iy 
< nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1))); + auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2))); + auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3))); + auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4))); + auto s12 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,1, 0,1, 0,1, 0,1 + auto s34 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,3, 2,3, 2,3, 2,3 + auto sumi = _mm256_add_epi32(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,1,2,3, 0,1,2,3 + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales32, sumi)); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster +// compared to the vanilla AVX2 version below. 
+struct IndexHelperIQ3S { + union index_t { + __m256i vec; + uint16_t val[16]; + }; + inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const { + auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); + const __mmask16 * m16 = (const __mmask16 *)qh; + index_t idx; + idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset); + values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]], + iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]); + values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]], + iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]); + } + const __m256i offset = _mm256_set1_epi16(256); +}; +#else +struct IndexHelperIQ3S { + union index_t { + __m256i vec; + uint32_t val[8]; + }; + inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const { + index_t idx; + auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs)); + auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask); + idx.vec = _mm256_or_si256(idx_h, idx_l); + values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], + iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); + idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8))); + idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask); + idx.vec = _mm256_or_si256(idx_h, idx_l); + values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], + iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); + } + const __m256i idx_mask = _mm256_set1_epi32(256); + 
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); +}; +#endif + +template +static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + auto smask = _mm256_set1_epi8(1); + union { __m256i vec; uint32_t val[8]; } helper; + union { __m128i vec; uint16_t val[8]; } hidx; + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; +#ifdef HAVE_FANCY_SIMD + __mmask32 mask[4]; +#endif + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto qs = iq3[ibl].qs; + auto qh = iq3[ibl].qh; + auto scale_bits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales); + auto scales8 = MM256_SET_M128I(_mm_srli_epi16(scale_bits, 4), scale_bits); + helper.vec = _mm256_or_si256(_mm256_slli_epi16(_mm256_and_si256(scales8, _mm256_set1_epi8(0xf)), 1), _mm256_set1_epi8(1)); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto qh32 = (const uint32_t *)qh; + auto idx_h = _mm_sllv_epi64(_mm_cvtepu8_epi16(_mm_set1_epi32(qh32[0])), _mm_set_epi64x(4, 8)); + for (int i = 0; i < 4; ++i) { + auto idx_l = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)(qs + 8*i))); + hidx.vec = _mm_or_si128(idx_l, _mm_and_si128(idx_h, _mm_set1_epi16(0x100))); idx_h = _mm_srli_epi16(idx_h, 1); + qx[i] = _mm256_set_epi32(iq3s_grid[hidx.val[7]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[4]], + iq3s_grid[hidx.val[3]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[0]]); + } + qs += 32; qh += 4; + auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib); + auto signs = MM256_SET_M128I(_mm_srli_epi16(signs128, 4), signs128); +#ifdef HAVE_FANCY_SIMD + auto scales = 
_mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib])); + mask[0] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1); + mask[1] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1); + mask[2] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1); + mask[3] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib); + auto sumi = _mm256_setzero_si256(); + auto ys = _mm256_shuffle_epi32(y, 0x00); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_mask_sub_epi8(ys, mask[0], _mm256_setzero_si256(), ys)); + ys = _mm256_shuffle_epi32(y, 0x55); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_mask_sub_epi8(ys, mask[1], _mm256_setzero_si256(), ys)); + ys = _mm256_shuffle_epi32(y, 0xaa); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_mask_sub_epi8(ys, mask[2], _mm256_setzero_si256(), ys)); + ys = _mm256_shuffle_epi32(y, 0xff); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_mask_sub_epi8(ys, mask[3], _mm256_setzero_si256(), ys)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales)); + } +#else + auto scales16 = _mm256_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib])); + auto scales = _mm256_unpacklo_epi16(scales16, scales16); + auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1); + auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1); + auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1); + auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i 
*)q8.y[iy][ibl].qs + ib); + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1))); + sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2))); + sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3))); + sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi)); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +template +static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = _mm256_set1_epi8(0xf); + auto m3 = _mm256_set1_epi8(0x30); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + union { __m256i vec; uint32_t val[8]; } hd; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d)); + auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl)); + auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1))); + if constexpr (nrc_y == 1) { + d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); + 
} + auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); + auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h); + auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4)); + hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3)); + auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3)); + auto shuffle = _mm256_set1_epi64x(0x0000000400000000); + auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])))); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); + qx[0] = _mm256_and_si256(bits1, mf); + qx[1] = 
_mm256_and_si256(bits2, mf); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), mf); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), mf); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); +#endif + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + float d8 = q8.scale(iy, ibl); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + //mul_mat_q4_k_r4_q8_k_avx2(n, vx, bx, info, nrc_x); + if constexpr (nrc_y == 1){ + mul_mat_q4_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); + } else { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto mf = _mm512_set1_epi8(0xf); + int nbl = n / QK_K; + using helper_t = union { __m512i vec; uint32_t val[16]; }; + helper_t hd, hm; + __m512 acc[nrc_y] = {}; + __m512i 
isum[nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q4_k_r4 * iq4l = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx); + const block_q4_k_r4 * iq4h = (const block_q4_k_r4 *)((const char *)vx + (ix+4)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l[ibl].d)); + auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h[ibl].d)); + auto dl = _mm256_castps256_ps128(d1); + auto ml = _mm256_extractf128_ps(d1, 1); + auto dh = _mm256_castps256_ps128(d2); + auto mh = _mm256_extractf128_ps(d2, 1); + auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1); + auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1); + m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f)); + auto slbits_l = _mm256_loadu_si256((const __m256i *)iq4l[ibl].scales_l); + auto shbits_l = _mm256_loadu_si256((const __m256i *)iq4h[ibl].scales_l); + auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1); + auto sld = _mm512_and_si512(slb, mf); + auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf); + auto slbits_h = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_h); + auto shbits_h = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_h); + auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h); + auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h); + auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1); + auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30)); + auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30)); + hd.vec = _mm512_or_si512(sld, shd); + hm.vec = _mm512_or_si512(slm, shm); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0])); + auto scales2 = 
_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8])); + auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); + scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0])); + scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8])); + auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m)); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)), + _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)), + _mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1); + qx[0] = _mm512_and_si512(bits1, mf); + qx[1] = _mm512_and_si512(bits2, mf); + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), mf); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), mf); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi)); + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm512_setzero_si512(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + 
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + acc[iy] = _mm512_setzero_ps(); + } + } + } +} +#else +template +static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q4_k_r4_q8_k_avx2(n, vx, bx, info, nrc_x); +} +#endif + +template +static void mul_mat_q5_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = _mm256_set1_epi8(0xf); + auto m10 = _mm256_set1_epi8(0x10); + auto m30 = _mm256_set1_epi8(0x30); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + union { __m256i vec; uint32_t val[8]; } hd; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d)); + auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl)); + auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1))); + if constexpr (nrc_y == 1) { + d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); + } + auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l); + auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h); + auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4)); + hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30)); + auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30)); + auto shuffle = _mm256_set1_epi64x(0x0000000400000000); + 
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])))); + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); + auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib); + auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4)); + qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits)); + qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2))); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1))); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), 
_mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); +#endif + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(scales_d, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + float d8 = q8.scale(iy, ibl); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1){ + mul_mat_q5_k_r4_q8_k_avx2<1>(n, vx, bx, info, nrc_x); + } else { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto mf = _mm512_set1_epi8(0xf); + auto m10 = _mm512_set1_epi8(0x10); + int nbl = n / QK_K; + using helper_t = union { __m512i vec; uint32_t val[16]; }; + helper_t hd, hm; + __m512 acc[nrc_y] = {}; + __m512i isum[nrc_y] = {}; + __m512i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q5_k_r4 * iq5l = (const 
block_q5_k_r4 *)((const char *)vx + (ix+0)*bx); + const block_q5_k_r4 * iq5h = (const block_q5_k_r4 *)((const char *)vx + (ix+4)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5l[ibl].d)); + auto d2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5h[ibl].d)); + auto dl = _mm256_castps256_ps128(d1); + auto ml = _mm256_extractf128_ps(d1, 1); + auto dh = _mm256_castps256_ps128(d2); + auto mh = _mm256_extractf128_ps(d2, 1); + auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1); + auto m4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(ml, ml)), _mm256_set_m128(mh, mh), 1); + m4 = _mm512_mul_ps(m4, _mm512_set1_ps(-0.5f)); + auto slbits_l = _mm256_loadu_si256((const __m256i *)iq5l[ibl].scales_l); + auto shbits_l = _mm256_loadu_si256((const __m256i *)iq5h[ibl].scales_l); + auto slb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_l), shbits_l, 1); + auto sld = _mm512_and_si512(slb, mf); + auto slm = _mm512_and_si512(_mm512_srli_epi16(slb, 4), mf); + auto slbits_h = _mm_loadu_si128((const __m128i *)iq5l[ibl].scales_h); + auto shbits_h = _mm_loadu_si128((const __m128i *)iq5h[ibl].scales_h); + auto slbits_h2 = MM256_SET_M128I(_mm_srli_epi16(slbits_h, 4), slbits_h); + auto shbits_h2 = MM256_SET_M128I(_mm_srli_epi16(shbits_h, 4), shbits_h); + auto shb = _mm512_inserti32x8(_mm512_castsi256_si512(slbits_h2), shbits_h2, 1); + auto shd = _mm512_and_si512(_mm512_slli_epi16(shb, 4), _mm512_set1_epi8(0x30)); + auto shm = _mm512_and_si512(_mm512_slli_epi16(shb, 2), _mm512_set1_epi8(0x30)); + hd.vec = _mm512_or_si512(sld, shd); + hm.vec = _mm512_or_si512(slm, shm); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto scales1 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+0])); + auto scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib+8])); + auto iscales = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); + scales1 = 
_mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+0])); + scales2 = _mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib+8])); + auto iscales_m = _mm512_inserti32x8(_mm512_castsi256_si512(scales1), scales2, 1); + auto scales_m = _mm512_mul_ps(m4, _mm512_cvtepi32_ps(iscales_m)); + auto lbits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+0)), + _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+0), 1); + auto lbits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l[ibl].qs+2*ib+1)), + _mm256_loadu_si256((const __m256i *)iq5h[ibl].qs+2*ib+1), 1); + auto hbits1 = _mm_loadu_si128((const __m128i*)iq5l[ibl].qh+ib); + auto hbits2 = _mm_loadu_si128((const __m128i*)iq5h[ibl].qh+ib); + auto hbl = MM256_SET_M128I(hbits1, _mm_slli_epi16(hbits1, 4)); + auto hbh = MM256_SET_M128I(hbits2, _mm_slli_epi16(hbits2, 4)); + auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbl), hbh, 1); + qx[0] = _mm512_or_si512(_mm512_and_si512(lbits1, mf), _mm512_and_si512(m10, hbits)); + qx[1] = _mm512_or_si512(_mm512_and_si512(lbits2, mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 2))); + qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits1, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 1))); + qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(lbits2, 4), mf), _mm512_and_si512(m10, _mm512_srli_epi16(hbits, 3))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + 
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi)); + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm512_setzero_si512(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + acc[iy] = _mm512_setzero_ps(); + } + } + } +} +#else +template +static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + mul_mat_q5_k_r4_q8_k_avx2(n, vx, bx, info, nrc_x); +} +#endif + +template +static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mxf = _mm256_set1_epi8(0xf); + auto m03 = _mm256_set1_epi8(0x03); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifdef HAVE_FANCY_SIMD + __m256i isum[nrc_y] = {}; +#else + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + int8_t scales[64]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm)); + auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1)); 
+ m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f)); + auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0); + auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1); + auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf); + auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf); + { + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9 + auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11 + auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13 + auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15 + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff)); + auto d8 = _mm256_set1_ps(q8.scale(iy, ibl)); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]); +#else + sumi = _mm256_add_epi32(sumi, 
_mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); + auto d8 = _mm256_set1_ps(q8.scale(iy, ibl)); + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]); + if constexpr (nrc_y == 1) { + d4 = _mm256_mul_ps(d4, d8); + } +#endif + } + } + all_scales1 = _mm256_and_si256(all_scales1, mxf); + all_scales2 = _mm256_and_si256(all_scales2, mxf); + _mm256_storeu_si256((__m256i *)scales+0, all_scales1); + _mm256_storeu_si256((__m256i *)scales+1, all_scales2); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib))); +#ifndef HAVE_FANCY_SIMD + auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales)); +#endif + auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib); + qx[0] = _mm256_and_si256(lb, m03); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = 
_mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + // Quants are in 0...3, so we can add up all of them as int16_t without overflowing + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } +#endif + } + } +#ifdef HAVE_FANCY_SIMD + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); + acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + isum[iy] = _mm256_setzero_si256(); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +// Q3_K_R4 x Q8_K kernel: each outer iteration handles 4 interleaved rows of x (hence nrc_x%4 == 0) against nrc_y rows of y. +// Quants are kept unsigned (2 low bits plus 1 high bit); the -4 offset is applied once per block through the q8 block sums. +template +static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); + auto m03 = _mm256_set1_epi8(0x03); + auto m04 = _mm256_set1_epi8(0x04); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifdef HAVE_FANCY_SIMD + __m256i isum[nrc_y]; +#else + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + int8_t scales[64]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i
*)iq3[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); +#ifndef HAVE_FANCY_SIMD + if constexpr (nrc_y == 1) { + d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); + } +#endif + auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l); + auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h); + auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + auto scales1 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(slb, m4), _mm256_and_si256(_mm256_slli_epi16(shb, 4), m30)), m32); + auto scales2 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(slb, 4), m4), _mm256_and_si256(shb, m30)), m32); + _mm256_storeu_si256((__m256i *)scales+0, scales1); + _mm256_storeu_si256((__m256i *)scales+1, scales2); + { +#ifndef HAVE_FANCY_SIMD + auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f)); +#endif + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9 + auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11 + auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13 + auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15 +#ifdef HAVE_FANCY_SIMD + s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-4)); + s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-4)); + s3 = 
_mm256_mullo_epi16(s3, _mm256_set1_epi16(-4)); + s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-4)); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff)); + isum[iy] = sumi; +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } +#endif + } + } + for (int ib = 0; ib < QK_K/32; ++ib) { + auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib))); +#ifndef HAVE_FANCY_SIMD + auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales)); +#endif + auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib); + auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib); + auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4)); + qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2))); + qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3))); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4))); + qx[3] = 
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + // Quants are in 0...7 (2 low bits | 0x04 high bit), so we can add up all of them as int16_t without overflowing + auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)); + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } +#endif + + } + } +#ifdef HAVE_FANCY_SIMD + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); + acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +// Q6_K_R4 x Q8_K kernel: each outer iteration handles 4 interleaved rows of x (hence nrc_x%4 == 0) against nrc_y rows of y. +// 6-bit quants are assembled from the ql nibbles and 2 bits of qh; the -32 offset is applied once per block through the q8 block sums. +template +static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 =
_mm256_set1_epi8(0xf); + auto m3 = _mm256_set1_epi8(0x30); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifdef HAVE_FANCY_SIMD + __m256i isum[nrc_y]; +#else + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); +#ifndef HAVE_FANCY_SIMD + if constexpr (nrc_y == 1) { + d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); + } +#endif + { +#ifndef HAVE_FANCY_SIMD + auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f)); +#endif + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row + auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9 + auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11 + auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13 + auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 
6, 7, 14, 15 +#ifdef HAVE_FANCY_SIMD + s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-32)); + s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-32)); + s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-32)); + s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-32)); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); + auto sumi = _mm256_setzero_si256(); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff)); + isum[iy] = sumi; +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } +#endif + } + } + const uint32_t * scales = (const uint32_t *)iq6[ibl].scales; + for (int ib = 0; ib < QK_K/32; ++ib) { + auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib))); +#ifndef HAVE_FANCY_SIMD + auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales)); +#endif + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1); + auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib); + qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4))); + qx[1] = 
_mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits)); + qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2))); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + // Quants are in 0...63, so we can add at most 4 as int16_t to be sure of no int16_t overflow + auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2)); + if constexpr (nrc_y == 1) { + acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]); + } else { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]); + } +#endif + } + } +#ifdef HAVE_FANCY_SIMD + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); + acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } +#endif + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = 
_mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +// The HAVE_FANCY_SIMD should only be #if defined(__AVX512VNNI__) && defined(__AVX512VL__) +// Q8_K_R8 x Q8_K kernel. Each outer iteration reads one group of 8 interleaved rows (8 fp16 scales are converted per block) +// and stores a full __m256 (8 results) at row ix, so nrc_x must be a multiple of 8: with only a multiple of 4 the last +// iteration would read and write 4 rows past nrc_x. +template +static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); +#ifndef HAVE_FANCY_SIMD + auto m1 = _mm256_set1_epi16(1); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i isum[nrc_y] = {}; + __m256i qx[4]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d)); + for (int ib = 0; ib < QK_K/16; ++ib) { + qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0); + qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1); + qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2); + qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3); +#ifdef HAVE_FANCY_SIMD + qx[0] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0), _mm256_set1_epi8(-128)); + qx[1] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1), _mm256_set1_epi8(-128)); + qx[2] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2), _mm256_set1_epi8(-128)); + qx[3] = _mm256_xor_si256(_mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3), _mm256_set1_epi8(-128)); +#else + auto s0 = _mm256_sign_epi8(qx[0], qx[0]); + auto s1 = _mm256_sign_epi8(qx[1], qx[1]); + auto s2 = _mm256_sign_epi8(qx[2], qx[2]); + auto s3 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib); + auto y = MM256_SET_M128I(y128, y128); +#ifdef HAVE_FANCY_SIMD + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00)); + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1],
_mm256_shuffle_epi32(y, 0x55)); + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa)); + isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff)); +#else + auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))); + auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]))); + auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))); + auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4)); +#endif + } + } +#ifdef HAVE_FANCY_SIMD + auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-128.f)); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); + acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); +#ifdef HAVE_FANCY_SIMD + auto bsums = (const float *)q8.y[iy][ibl].bsums; + acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]); +#endif + isum[iy] = _mm256_setzero_si256(); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +#ifdef __AVX512BF16__ +template +static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%16 == 0); + const ggml_bf16_t * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy); + for (int ix = 0; ix < nrc_x/32; ++ix) { + __m512 acc[2*nrc_y] = {}; + __m512bh qx[8]; + const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx); + const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx); + for (int ib = 0; ib < n/8;
++ib) { + qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0); + qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1); + qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2); + qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3); + qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0); + qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1); + qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2); + qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib); + //auto y = _mm512_broadcast_i32x4(y128); + auto y256 = MM256_SET_M128I(y128, y128); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(32*ix+ 0, iy, acc[2*iy+0]); + info.store(32*ix+16, iy, acc[2*iy+1]); + } + } + for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) { + __m512 acc[nrc_y] = {}; + __m512bh qx[4]; + const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char 
*)vx + (ix+0)*bx); + for (int ib = 0; ib < n/8; ++ib) { + qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0); + qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1); + qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2); + qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib); + auto y256 = MM256_SET_M128I(y128, y128); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + } + } +} +#endif + +template +//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff, +inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff, + __m256i * isum, int16_t min) { + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + if constexpr (nrc_y == 1) { + auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // 
blocks 0, 1, 8, 9 + auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11 + auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13 + auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15 + auto sumi = _mm256_setzero_si256(); + auto bsums = q8.load_bsums(0, ibl); +#ifdef HAVE_FANCY_SIMD + sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00)); + sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55)); + sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa)); + sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + isum[0] = _mm256_mullo_epi32(sumi, _mm256_set1_epi32(min)); + + } else { + auto s1 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9 + auto s2 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11 + auto s3 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13 + auto s4 = _mm256_mullo_epi16(_mm256_set1_epi16(min), MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15 + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); +#ifdef HAVE_FANCY_SIMD + isum[iy] = 
_mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + } + } +} + +template +inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff, + __m256i extra, __m256i * isum, int8_t min, int8_t delta) { + auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101); + auto vdelta = _mm256_set1_epi8(delta); + auto vmin = _mm256_set1_epi8(min); + auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask))); + auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask))); + auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto m1 = 
_mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row + auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row + auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row + auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row + auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)), + MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9 + auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)), + MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11 + auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)), + MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13 + auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)), + MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15 + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = q8.load_bsums(iy, ibl); +#ifdef HAVE_FANCY_SIMD + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa)); + isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff)); +#else + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00))); + isum[iy] = _mm256_add_epi32(isum[iy], 
_mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff))); +#endif + } +} + +template +static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto ms = _mm256_set1_epi8(4); + auto m03 = _mm256_set1_epi8(0x03); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + static const uint8_t kvalues_iq2nl[32] = {1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54, 1, 19, 33, 49, 6, 24, 38, 54}; + auto values = _mm256_loadu_si256((const __m256i*)kvalues_iq2nl); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq2[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq2[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales); + auto i8scales1 = _mm256_add_epi8(_mm256_and_si256(slbits, m4), _mm256_set1_epi8(-8)); + auto i8scales2 = _mm256_add_epi8(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), _mm256_set1_epi8(-8)); + 
_mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -32); + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib); + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_and_si256(lb, m03); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03); + qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], 
_mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +template +static void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto ms = _mm256_set1_epi8(8); + auto m03 = _mm256_set1_epi8(0x03); + auto m04 = _mm256_set1_epi8(0x04); + auto smask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + auto values128 = _mm_loadu_si128((const __m128i *)iq3nl_values); + auto values = MM256_SET_M128I(values128, values128); + values = _mm256_add_epi8(values, _mm256_set1_epi8(64)); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#ifndef HAVE_FANCY_SIMD + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = 
_mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq3[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l); + auto sl1 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(slbits, m4), 1), _mm256_set1_epi8(1)); + auto sl2 = _mm256_add_epi8(_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4), 1), _mm256_set1_epi8(1)); + auto sh = _mm256_set1_epi64x(((const uint64_t *)iq3[ibl].scales_h)[0]); + auto sh1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(sh, smask), smask), _mm256_set1_epi8(1)); + auto sh2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(sh, 4), smask), smask), _mm256_set1_epi8(1)); + auto i8scales1 = _mm256_sign_epi8(sl1, sh1); + auto i8scales2 = _mm256_sign_epi8(sl2, sh2); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -64); + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib); + auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib); + auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4)); + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 3)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2))); + qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3))); + qx[2] = 
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4))); + qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5))); + qx[0] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[0], shift)); + qx[1] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[1], shift)); + qx[2] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[2], shift)); + qx[3] = _mm256_shuffle_epi8(values, _mm256_add_epi8(qx[3], shift)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)); + auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)); + auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)); + auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); 
+ } + } +} + +template +static void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); + auto ms = _mm256_set1_epi8(4); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); +#ifdef HAVE_FANCY_SIMD + auto values = load_iq4nl_values_256(); + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#else + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq4[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32); + auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + 
_mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128); +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 2)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4))); + qx[1] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4))); + qx[2] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4))); + qx[3] = _mm256_add_epi8(shift, _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4))); +#ifndef HAVE_FANCY_SIMD + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, 
_mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + +template +static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m30 = _mm256_set1_epi8(0x30); + auto m32 = _mm256_set1_epi8(32); + auto ms = _mm256_set1_epi8(2); + auto shift_shuffle = _mm256_set_epi64x(0x0707070706060606, 0x0505050504040404, 0x0303030302020202, 0x0101010100000000); + __m256i values[2]; + { + auto val1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto val2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(val1, val1); + values[1] = MM256_SET_M128I(val2, val2); +#ifdef HAVE_FANCY_SIMD + values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128)); + values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128)); +#endif + } +#ifdef HAVE_FANCY_SIMD + static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + auto 
shuff = _mm256_loadu_si256((const __m256i *)k_shuff); +#else + auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); +#endif + int nbl = n / QK_K; + __m256 acc[nrc_y] = {}; + __m256i qx[4]; + uint64_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5[ibl].d)); + auto d4 = _mm256_set_m128(dl, dl); + auto extra = _mm256_set1_epi64x(*(const uint64_t *)iq5[ibl].extra); + auto slbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l); + auto sl1 = _mm256_and_si256(slbits, m4); + auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4); + auto shbits = _mm_loadu_si128((const __m128i*)iq5[ibl].scales_h); + auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits); + auto i8scales1 = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(m30, _mm256_slli_epi16(sh, 4))), m32); + auto i8scales2 = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(m30, sh)), m32); + _mm256_storeu_si256((__m256i *)stored_scales+0, i8scales1); + _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2); + __m256i isum[nrc_y] = {}; +#ifdef HAVE_FANCY_SIMD + if constexpr (nrc_y == 1) { + iq234_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, isum, -128); + } else { + iq2345_k_accum_mins(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2); + } +#endif + for (int ib = 0; ib < QK_K/32; ++ib) { +#ifdef HAVE_FANCY_SIMD + auto scales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(stored_scales + ib))); +#else + auto scales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(stored_scales[ib])), s_shuffle); +#endif + auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0); + auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1); + auto hbits = 
_mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib); + auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits); + qx[0] = _mm256_and_si256(lbits1, m4); + qx[1] = _mm256_and_si256(lbits2, m4); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); + +#ifdef HAVE_FANCY_SIMD + auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); + auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); + qx[0] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[1]); + q5vh = _mm256_shuffle_epi8(values[1], qx[1]); + qx[1] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[2]); + q5vh = _mm256_shuffle_epi8(values[1], qx[2]); + qx[2] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[3]); + q5vh = _mm256_shuffle_epi8(values[1], qx[3]); + qx[3] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20)), q5vl, q5vh); + + if constexpr (nrc_y == 1) { + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + } +#else + + auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); + auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); + qx[0] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[1]); + q5vh = _mm256_shuffle_epi8(values[1], 
qx[1]); + qx[1] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[2]); + q5vh = _mm256_shuffle_epi8(values[1], qx[2]); + qx[2] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[3]); + q5vh = _mm256_shuffle_epi8(values[1], qx[3]); + qx[3] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20))); + + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + auto s1 = _mm256_sign_epi8(qx[0], qx[0]); + auto s2 = _mm256_sign_epi8(qx[1], qx[1]); + auto s3 = _mm256_sign_epi8(qx[2], qx[2]); + auto s4 = _mm256_sign_epi8(qx[3], qx[3]); +#endif + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff)); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales, sumi)); +#else + auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])); + auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])); + auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])); + auto sumi4 = _mm256_maddubs_epi16(s4, 
_mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi1), _mm256_madd_epi16(scales, sumi2))); + isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(_mm256_madd_epi16(scales, sumi3), _mm256_madd_epi16(scales, sumi4))); +#endif + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + acc[iy] = _mm256_setzero_ps(); + info.store(ix+0, iy, sum); + } + } +} + template inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) { if (j == 0) { @@ -2105,15 +5576,17 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons } } +// TODO: find the bug that causes this to be called without HAVE_FANCY_SIMD, which triggers +// writing 4 vvalues into scales, which is of size 2. inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) { -#ifdef HAVE_FANCY_SIMD +//#ifdef HAVE_FANCY_SIMD auto shuffle = j == 0 ? 
_mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100) : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908); scales[0] = _mm256_shuffle_epi8(all_scales, shuffle); scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4))); -#else - set_scales_8(all_scales, j, scales); -#endif +//#else +// set_scales_8(all_scales, j, scales); +//#endif } inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) { @@ -2498,57 +5971,17 @@ IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const template static void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); +#ifdef HAVE_FANCY_SIMD if constexpr (nrc_y == 1) { mul_mat_qX_K_q8_K_IQ_1(n, vx, bx, info, nrc_x); } else { mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); } +#else + mul_mat_qX_K_q8_K_IQ_N(n, vx, bx, info, nrc_x); +#endif } -//#ifdef HAVE_FANCY_SIMD -// Strangely enough, the following implementation makes PP ~6% slower and TG ~6% faster -// compared to the vanilla AVX2 version below. 
-//struct IndexHelperIQ3S { -// union index_t { -// __m256i vec; -// uint16_t val[16]; -// }; -// inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const { -// auto idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); -// const __mmask16 * m16 = (const __mmask16 *)qh; -// index_t idx; -// idx.vec = _mm256_mask_add_epi16(idx_l, m16[0], idx_l, offset); -// values[0] = _mm256_set_epi32(iq3s_grid[idx.val[ 7]], iq3s_grid[idx.val[ 6]], iq3s_grid[idx.val[ 5]], iq3s_grid[idx.val[ 4]], -// iq3s_grid[idx.val[ 3]], iq3s_grid[idx.val[ 2]], iq3s_grid[idx.val[ 1]], iq3s_grid[idx.val[ 0]]); -// values[1] = _mm256_set_epi32(iq3s_grid[idx.val[15]], iq3s_grid[idx.val[14]], iq3s_grid[idx.val[13]], iq3s_grid[idx.val[12]], -// iq3s_grid[idx.val[11]], iq3s_grid[idx.val[10]], iq3s_grid[idx.val[ 9]], iq3s_grid[idx.val[ 8]]); -// } -// const __m256i offset = _mm256_set1_epi16(256); -//}; -//#else -struct IndexHelperIQ3S { - union index_t { - __m256i vec; - uint32_t val[8]; - }; - inline void make2(const uint8_t * qs, const uint8_t * qh, __m256i * values) const { - index_t idx; - auto idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)qs)); - auto idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[0]), idx_shift), idx_mask); - idx.vec = _mm256_or_si256(idx_h, idx_l); - values[0] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], - iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); - idx_l = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qs+8))); - idx_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[1]), idx_shift), idx_mask); - idx.vec = _mm256_or_si256(idx_h, idx_l); - values[1] = _mm256_set_epi32(iq3s_grid[idx.val[7]], iq3s_grid[idx.val[6]], iq3s_grid[idx.val[5]], iq3s_grid[idx.val[4]], - iq3s_grid[idx.val[3]], iq3s_grid[idx.val[2]], iq3s_grid[idx.val[1]], iq3s_grid[idx.val[0]]); - } - const __m256i 
idx_mask = _mm256_set1_epi32(256); - const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); -}; -//#endif - struct DequantizerIQ3S final : public BaseDequantizer { DequantizerIQ3S(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} @@ -3372,10 +6805,10 @@ struct Q_Unpacker { } }; -struct Q8_0_x4_Unpacker { +struct Q8_0_x4_Unpacker_256 { using Sum4T = Sum4TypeQ80; inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} const char * cx_0; const block_q8_0_x4 * x; @@ -3401,6 +6834,44 @@ struct Q8_0_x4_Unpacker { } }; +#ifdef HAVE_FANCY_SIMD +struct Q8_0_x4_Unpacker_512 { + using Sum4T = Sum4TypeQ81; + inline static int block_size() { return QK8_0; } + Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + + const char * cx_0; + const block_q8_0_x4 * x; + size_t bx; + const __m128 min = _mm_set1_ps(-128.f); + + __m256i qx[4]; + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { + auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80)); + } + return _mm256_set_m128(_mm_mul_ps(scales, min), scales); + } + inline auto set_block(int i) { + auto q8 = (const block_q8_0 *)(x + i); + qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); + qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(0x80)); + float d = GGML_FP16_TO_FP32(q8->d); + return std::make_pair(d, -128.f*d); + } +}; +using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512; +#else +using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256; +#endif + struct 
Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -3474,21 +6945,60 @@ struct QFBase { static inline Data load4Floats(const Float * x) { return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0); } + static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { + acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc); + acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); + acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc); + acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc); + return acc; + } + static inline Acc acc_r4_first(const Data * xv, const Data& yv) { + auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00)); + acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); + acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc); + acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc); + return acc; + } + static inline __m128 hsum_r4(Acc acc) { + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3)); + return _mm_add_ps(sum1, sum2); + } #else constexpr static int k_step = 8; using Data = __m256; using Acc = __m256; static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); } static inline Data load(const float * x) { return _mm256_loadu_ps(x); } + static inline Data load(const ggml_bf16_t * x) { + return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16)); + } static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm256_fmadd_ps(y, x, prev); } + static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { + acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc); + acc = _mm256_fmadd_ps(xv[1], 
_mm256_shuffle_ps(yv, yv, 0x55), acc); + acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc); + acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc); + return acc; + } + static inline Acc acc_r4_first(const Data * xv, const Data& yv) { + auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00)); + acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); + acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc); + acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc); + return acc; + } static inline Acc acc_first(const Data& y, const Data& x) { return _mm256_mul_ps(y, x); } static inline float hsum(Acc acc) { return hsum_float_8(acc); } + static inline __m128 hsum_r4(Acc acc) { + return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); + } template static inline Data load4Floats(const Float * x) { return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0); @@ -3496,6 +7006,9 @@ struct QFBase { #endif static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); } static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); } + static inline __m128 load128(const ggml_bf16_t * x) { + return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16)); + } }; template struct QFT final : public QFBase { constexpr static int nrc = nrc_in; @@ -3507,6 +7020,31 @@ template struct QFT final : public QFBase { } IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); } IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); } + IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const { + xv[0] = load1(ix+0, i); + xv[1] = load1(ix+1, i); + xv[2] = load1(ix+2, i); + xv[3] = load1(ix+3, i); +#ifdef HAVE_FANCY_SIMD + auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]); + auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]); + auto t2 = 
_mm512_unpackhi_ps(xv[0], xv[1]); + auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]); + xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1))); + xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1))); + xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3))); + xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3))); +#else + auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]); + auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]); + auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]); + auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]); + xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1))); + xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1))); + xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3))); + xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3))); +#endif + } const Float * y[nrc]; }; @@ -3552,6 +7090,56 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); } +template +inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + int nb = n/QFBase::k_step; + Qy y(info); + Qx x(cx + ix0*bx, bx); + QFBase::Data xv[Qx::nrc]; + QFBase::Acc acc[Qx::nrc*Qy::nrc]; + auto yv = y.load1(0, 0); + for (int ix = 0; ix < Qx::nrc; ++ix) { + xv[ix] = x.load1(ix, 0); + acc[ix] = QFBase::acc_first(yv, xv[ix]); + } + for (int iy = 1; iy < Qy::nrc; ++iy) { + yv = y.load1(iy, 0); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]); + } + for (int i = 1; i < nb; ++i) { + yv = y.load1(0, i); + for (int ix = 0; ix < Qx::nrc; ++ix) { + xv[ix] = x.load1(ix, i); + acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]); + 
} + for (int iy = 1; iy < Qy::nrc; ++iy) { + yv = y.load1(iy, i); + for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]); + } + } + for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); +} + +template +inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) { + static_assert(Qx::nrc%4 == 0); + int nb = D/QFBase::k_step; + Qy y(info); + Qx x(cx + ix0*bx, bx); + QFBase::Data xv[Qx::nrc]; + QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {}; + for (int i = 0; i < nb; ++i) { + for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix); + for (int iy = 0; iy < Qy::nrc; ++iy) { + auto yv = y.load1(iy, i); + for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv); + } + } + for (int iy = 0; iy < Qy::nrc; ++iy) { + for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy])); + } +} + // This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done // in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. @@ -3560,7 +7148,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in #ifdef __AVX512F__ constexpr int k_nx = 5; #else - constexpr int k_nx = 2; + constexpr int k_nx = nrc_y == 1 ? 
4 : 2; #endif const char * cx = (const char *)vx; for (int ix = 0; ix < nrc_x/k_nx; ++ix) { @@ -3569,14 +7157,26 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in int last_x = k_nx*(nrc_x/k_nx); if (last_x == nrc_x) return; int nx = nrc_x - last_x; +#ifdef __AVX512F__ switch (nx) { case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; -#ifdef __AVX512F__ case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; case 4: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; -#endif } +#else + if constexpr (nrc_y == 1) { + switch (nx) { + case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + case 2: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + case 3: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + } + } else { + switch (nx) { + case 1: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; + } + } +#endif } #ifdef __AVX512BF16__ @@ -3585,7 +7185,8 @@ struct QFBaseBF16 { using Data = __m512bh; using Acc = __m512; static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); } - static inline Acc acc(Acc prev, const Data& y, const Data& x) { + //static inline Acc acc(Acc prev, const Data& y, const Data& x) { + static inline Acc acc(Acc prev, Data y, Data x) { return _mm512_dpbf16_ps(prev, y, x); } static inline Acc acc_first(const Data& y, const Data& x) { @@ -3636,6 +7237,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, } for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix])); } + template void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { constexpr int k_nx = nrc_y <= 2 ? 
8 : 5; @@ -3850,6 +7452,17 @@ void set_mul_mat_bf16(MulMat& mm) { mm.funcs[3] = mul_mat_fX_fY_T<4>; mm.funcs[4] = mul_mat_fX_fY_T<5>; } +void set_mul_mat_bf16_r16(MulMat& mm) { + for (auto& f : mm.funcs) f = nullptr; + mm.funcs[0] = mul_mat_bf16_r16_bf16<1>; + mm.funcs[1] = mul_mat_bf16_r16_bf16<2>; + mm.funcs[2] = mul_mat_bf16_r16_bf16<3>; + mm.funcs[3] = mul_mat_bf16_r16_bf16<4>; + mm.funcs[4] = mul_mat_bf16_r16_bf16<5>; + mm.funcs[5] = mul_mat_bf16_r16_bf16<6>; + mm.funcs[6] = mul_mat_bf16_r16_bf16<7>; + mm.funcs[7] = mul_mat_bf16_r16_bf16<8>; +} #endif bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { @@ -3861,6 +7474,20 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { switch (typeB) { #ifdef __AVX512BF16__ case GGML_TYPE_BF16: set_mul_mat_bf16(mm); break; +#else + case GGML_TYPE_BF16: set_mul_mat_f(mm); break; + case GGML_TYPE_F32: set_mul_mat_f(mm); break; +#endif + default: return false; + } + return true; + } + + if (typeA == GGML_TYPE_BF16_R16) { + if (ne00 % 16) return false; + switch (typeB) { +#ifdef __AVX512BF16__ + case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break; #endif default: return false; } @@ -3990,40 +7617,340 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>; expected_typeB = GGML_TYPE_Q8_K64; break; + case GGML_TYPE_IQ2_BN_R4: + assert (ne00 % QK_IQ1BN == 0); + mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>; + mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>; + mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>; + mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>; + mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>; + mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>; +//#ifdef HAVE_FANCY_SIMD + mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>; + mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>; +//#endif + expected_typeB = GGML_TYPE_Q8_K16; + break; case GGML_TYPE_Q4_0: assert (ne00 % QK4_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = 
GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q4_1: assert (ne00 % QK4_1 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0: assert (ne00 % QK5_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_1: assert (ne00 % QK5_1 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q6_0: assert (ne00 % QK6_0 == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q8_0: assert (ne00 % QK8_0 == 0); +#ifdef HAVE_FANCY_SIMD MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; +#else + MulMat::set_functions(mm); + expected_typeB = GGML_TYPE_Q8_0_X4; +#endif break; case GGML_TYPE_IQ4_NL: assert (ne00 % QK4_NL == 0); MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_1; + expected_typeB = GGML_TYPE_Q8_1_X4; + break; + case GGML_TYPE_IQ4_NL_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_iq4_nl_r4_q8_1<1>; + mm.funcs[1] = mul_mat_iq4_nl_r4_q8_1<2>; + mm.funcs[2] = mul_mat_iq4_nl_r4_q8_1<3>; + mm.funcs[3] = mul_mat_iq4_nl_r4_q8_1<4>; + mm.funcs[4] = mul_mat_iq4_nl_r4_q8_1<5>; + mm.funcs[5] = mul_mat_iq4_nl_r4_q8_1<6>; + mm.funcs[6] = mul_mat_iq4_nl_r4_q8_1<7>; + mm.funcs[7] = mul_mat_iq4_nl_r4_q8_1<8>; + expected_typeB = GGML_TYPE_Q8_1_X4; + break; + case GGML_TYPE_IQ4_XS_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq4_xs_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq4_xs_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq4_xs_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq4_xs_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq4_xs_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_IQ4_KS_R4: + 
assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq4_ks_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq4_ks_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq4_ks_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq4_ks_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq4_ks_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq4_ks_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq4_ks_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq4_ks_r4_q8_k<8>; +#ifndef HAVE_FANCY_SIMD + // For some reason Zen4 does not like this particular function + mm.func16 = mul_mat_iq4_ks_r4_q8_k<16>; +#endif + expected_typeB = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_IQ2_XXS_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq2_xxs_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq2_xxs_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq2_xxs_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq2_xxs_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq2_xxs_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq2_xxs_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq2_xxs_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq2_xxs_r4_q8_k<8>; + mm.func16 = mul_mat_iq2_xxs_r4_q8_k<16>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ2_XS_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq2_xs_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq2_xs_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq2_xs_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq2_xs_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq2_xs_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq2_xs_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq2_xs_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq2_xs_r4_q8_k<8>; +#ifndef HAVE_FANCY_SIMD + // For some reason Zen4 does not like this particular function + mm.func16 = mul_mat_iq2_xs_r4_q8_k_16; +#endif + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ2_S_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq2_s_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq2_s_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq2_s_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq2_s_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq2_s_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq2_s_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq2_s_r4_q8_k<7>; + 
mm.funcs[7] = mul_mat_iq2_s_r4_q8_k<8>; + mm.func16 = mul_mat_iq2_s_r4_q8_k_16; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ3_XXS_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq3_xxs_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq3_xxs_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq3_xxs_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq3_xxs_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq3_xxs_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq3_xxs_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq3_xxs_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq3_xxs_r4_q8_k<8>; + mm.func16 = mul_mat_iq3_xxs_r4_q8_k<16>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ3_S_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq3_s_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq3_s_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq3_s_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq3_s_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq3_s_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq3_s_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq3_s_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq3_s_r4_q8_k<8>; + mm.func16 = mul_mat_iq3_s_r4_q8_k<16>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q2_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_q2_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_q2_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_q2_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_q2_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_q2_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_q2_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_q2_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_q2_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q3_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_q3_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_q3_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_q3_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_q3_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_q3_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_q3_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_q3_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_q3_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q4_K_R4: + assert (ne00 % QK_K 
== 0); + mm.funcs[0] = mul_mat_q4_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_q4_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_q4_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_q4_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_q4_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_q4_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_q4_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_q4_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_Q5_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_q5_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_q5_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_q5_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_q5_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_q5_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_q5_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_q5_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_q5_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_Q6_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_q6_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_q6_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_q6_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_q6_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_q6_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_q6_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_q6_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_q6_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q8_K_R8: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_q8_k_r8_q8_k<1>; + mm.funcs[1] = mul_mat_q8_k_r8_q8_k<2>; + mm.funcs[2] = mul_mat_q8_k_r8_q8_k<3>; + mm.funcs[3] = mul_mat_q8_k_r8_q8_k<4>; + mm.funcs[4] = mul_mat_q8_k_r8_q8_k<5>; + mm.funcs[5] = mul_mat_q8_k_r8_q8_k<6>; + mm.funcs[6] = mul_mat_q8_k_r8_q8_k<7>; + mm.funcs[7] = mul_mat_q8_k_r8_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_KR8; + break; + case GGML_TYPE_IQ4_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq4_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq4_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq4_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq4_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq4_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq4_k_r4_q8_k<6>; + mm.funcs[6] = 
mul_mat_iq4_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq4_k_r4_q8_k<8>; + mm.func16 = mul_mat_iq4_k_r4_q8_k<16>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ5_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq5_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq5_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq5_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq5_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq5_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq5_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq5_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq5_k_r4_q8_k<8>; + mm.func16 = mul_mat_iq5_k_r4_q8_k<16>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ2_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq2_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq2_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq2_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq2_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq2_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq2_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq2_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq2_k_r4_q8_k<8>; + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ3_K_R4: + assert (ne00 % QK_K == 0); + mm.funcs[0] = mul_mat_iq3_k_r4_q8_k<1>; + mm.funcs[1] = mul_mat_iq3_k_r4_q8_k<2>; + mm.funcs[2] = mul_mat_iq3_k_r4_q8_k<3>; + mm.funcs[3] = mul_mat_iq3_k_r4_q8_k<4>; + mm.funcs[4] = mul_mat_iq3_k_r4_q8_k<5>; + mm.funcs[5] = mul_mat_iq3_k_r4_q8_k<6>; + mm.funcs[6] = mul_mat_iq3_k_r4_q8_k<7>; + mm.funcs[7] = mul_mat_iq3_k_r4_q8_k<8>; +#ifdef HAVE_FANCY_SIMD + mm.func16 = mul_mat_iq3_k_r4_q8_k<16>; +#endif + expected_typeB = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q4_0_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_q4_0_r4_q8_1<1>; + mm.funcs[1] = mul_mat_q4_0_r4_q8_1<2>; + mm.funcs[2] = mul_mat_q4_0_r4_q8_1<3>; + mm.funcs[3] = mul_mat_q4_0_r4_q8_1<4>; + mm.funcs[4] = mul_mat_q4_0_r4_q8_1<5>; + mm.funcs[5] = mul_mat_q4_0_r4_q8_1<6>; + mm.funcs[6] = mul_mat_q4_0_r4_q8_1<7>; + mm.funcs[7] = mul_mat_q4_0_r4_q8_1<8>; + expected_typeB = GGML_TYPE_Q8_1_X4; + 
break; + case GGML_TYPE_Q5_0_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_q5_0_r4_q8_1<1>; + mm.funcs[1] = mul_mat_q5_0_r4_q8_1<2>; + mm.funcs[2] = mul_mat_q5_0_r4_q8_1<3>; + mm.funcs[3] = mul_mat_q5_0_r4_q8_1<4>; + mm.funcs[4] = mul_mat_q5_0_r4_q8_1<5>; + mm.funcs[5] = mul_mat_q5_0_r4_q8_1<6>; + mm.funcs[6] = mul_mat_q5_0_r4_q8_1<7>; + mm.funcs[7] = mul_mat_q5_0_r4_q8_1<8>; + expected_typeB = GGML_TYPE_Q8_1_X4; + break; + case GGML_TYPE_Q6_0_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_q6_0_r4_q8_1<1>; + mm.funcs[1] = mul_mat_q6_0_r4_q8_1<2>; + mm.funcs[2] = mul_mat_q6_0_r4_q8_1<3>; + mm.funcs[3] = mul_mat_q6_0_r4_q8_1<4>; + mm.funcs[4] = mul_mat_q6_0_r4_q8_1<5>; + mm.funcs[5] = mul_mat_q6_0_r4_q8_1<6>; + mm.funcs[6] = mul_mat_q6_0_r4_q8_1<7>; + mm.funcs[7] = mul_mat_q6_0_r4_q8_1<8>; + expected_typeB = GGML_TYPE_Q8_1_X4; + break; + case GGML_TYPE_Q8_0_R4: + assert (ne00 % QK4_NL == 0); + mm.funcs[0] = mul_mat_q8_0_r4_q8_1<1>; + mm.funcs[1] = mul_mat_q8_0_r4_q8_1<2>; + mm.funcs[2] = mul_mat_q8_0_r4_q8_1<3>; + mm.funcs[3] = mul_mat_q8_0_r4_q8_1<4>; + mm.funcs[4] = mul_mat_q8_0_r4_q8_1<5>; + mm.funcs[5] = mul_mat_q8_0_r4_q8_1<6>; + mm.funcs[6] = mul_mat_q8_0_r4_q8_1<7>; + mm.funcs[7] = mul_mat_q8_0_r4_q8_1<8>; + expected_typeB = GGML_TYPE_Q8_1_X4; break; default: @@ -5193,7 +9120,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; -template +template void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; @@ -6345,6 +10272,135 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } } +template struct Q8_16 { + + constexpr static int nrc_y = nrc; + + Q8_16(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto ptr = (const float *)info.src1_row(iy); + std::memcpy(d + 5*iy, ptr, 5*sizeof(float)); + y[iy] = (const int8_t *)(ptr + 5); + } + } + + inline int8x16x4_t load_quants(int iy, int i) const { 
return vld1q_s8_x4(y[iy] + 64*i); } + inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); } + inline float scale(int iy, int k) const { return d[5*iy+k]; } + inline float sum_row(int iy) const { return d[5*iy + 4]; } + inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); } + + float d[5*nrc_y]; + const int8_t * y[nrc_y]; +}; + +template +static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + Q8_16 q8(info); + auto m3 = vdupq_n_u8(0x3); + int nb = n / QK_IQ1BN; + if constexpr (nrc_y == 1) { + auto mc = vdupq_n_u8(0xc); + int32x4_t acc[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0); + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = vld1q_f32(dptr); + const uint8_t * iq2 = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto y = q8.load_quants(0, ib); + for (int j = 0; j < 4; ++j) { + auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j); + auto bits2 = vshrq_n_u8(bits1, 4); + acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0); + acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1); + acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2); + acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3); + } + } + auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0))); + auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy); + auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy); + dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1))); + sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy); + sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy); + dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2))); + sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy); + sumf2 = vfmaq_f32(sumf2, 
vcvtq_f32_s32(acc[5]), dy); + dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3))); + sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy); + sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy); + auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2); + sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0))); + info.store(ix, 0, sumf); + } + } else { + int32x4_t acc[4*nrc_y] = {}; + uint8x16_t qx[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = vld1q_f32(dptr); + const uint8_t * iq2 = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits = vld1q_u8_x2(iq2 + 64*ib); + qx[0] = vandq_u8(bits.val[0], m3); + qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3); + qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3); + qx[3] = vshrq_n_u8(bits.val[0], 6); + qx[4] = vandq_u8(bits.val[1], m3); + qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3); + qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3); + qx[7] = vshrq_n_u8(bits.val[1], 6); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants_32(iy, 2*ib+0); + acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0); + acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1); + acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2); + acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3); + acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0); + acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1); + acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2); + acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3); + } + bits = vld1q_u8_x2(iq2 + 64*ib + 32); + qx[0] = vandq_u8(bits.val[0], m3); + qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3); + qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3); + qx[3] = vshrq_n_u8(bits.val[0], 6); + qx[4] = vandq_u8(bits.val[1], m3); + qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3); + qx[6] = 
vandq_u8(vshrq_n_u8(bits.val[1], 4), m3); + qx[7] = vshrq_n_u8(bits.val[1], 6); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants_32(iy, 2*ib+1); + acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0); + acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1); + acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2); + acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3); + acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0); + acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1); + acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2); + acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0)); + sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1)); + sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2)); + sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3)); + sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy))); + info.store(ix, iy, sumf); + acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0); + } + } + } +} + template static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { const int nb = n / QK_IQ1BN; @@ -6427,38 +10483,1678 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn } } +IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) { + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); + 
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); + return sumi; +} + +IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) { + int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) }; + sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0); + sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0); + sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1); + sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1); + sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2); + sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2); + sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3); + sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3); + return sumi; +} + +IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) { + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[0], y, 0); + sumi = vdotq_laneq_s32(sumi, qx[1], y, 1); + sumi = vdotq_laneq_s32(sumi, qx[2], y, 2); + sumi = vdotq_laneq_s32(sumi, qx[3], y, 3); + return sumi; +} + +IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) { + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 + qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 + qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 + qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 + qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 + qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 + qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 +} + +template +void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto 
m4 = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + auto m32 = vdupq_n_s8(-32); + auto values = vld1q_s8(iq4k_values); + int nbl = n / QK_K; + int8x16_t qx[8]; + int8x16x2_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); + auto sl = vld1q_u8(iq4[ibl].scales_l); + auto sh8 = vld1_u8(iq4[ibl].scales_h); + auto sh = vcombine_u8(sh8, vshr_n_u8(sh8, 2)); + iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl, m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + iscales.val[1] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl, 4), vandq_u8(sh, m3)), m32); + int32x4_t isum[nrc_y] = {}; + for (int is = 0; is < 2; ++is) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1)); + scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1)); + scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2)); + scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2)); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); + prepare_iq4_nl_quants(values, m4, bits, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = vdupq_n_u8(0xf); + auto values = 
vld1q_s8(iq4k_values); + int nbl = n / QK_K; + int8x16_t qx[8]; + int16x8x4_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto dptr = (const float *)((const char *)vx + ix*bx); + auto d4 = vld1q_f32(dptr); + const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto sas = vld1q_u8_x2(iq4[ibl].scales); + auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254)); + iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + scale = vandq_u8(sas.val[1], vdupq_n_u8(254)); + iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127)); + iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127)); + // Adding the block shifts costs us ~9% in performance drop. + // Is there a better way? 
+ sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2); + sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2); + { + auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0]))); + auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0]))); + auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1]))); + auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1]))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums); + auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]); + auto b8 = vget_low_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vget_high_s16(bs); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } + } + for (int is = 0; is < 2; ++is) { + scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0])); + scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0])); + scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1])); + scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); + prepare_iq4_nl_quants(values, m4, bits, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } 
+ for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(d4, acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + int8x16_t qx[8]; + SignHelper sh; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); + auto qs = iq2[ibl].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + auto sas = vld1q_u8(iq2[ibl].sas + 16*ib); + auto scale_bits = vandq_u8(sas, vdupq_n_u8(1)); + auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402))); + auto signs128 = vandq_u8(sas, vdupq_n_u8(254)); + signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1)); + sh.init(); + for (int i = 0; i < 8; ++i) { + qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]}); + sh.apply_signs_1((uint8x16_t *)qx+i, signs128); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib); + auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]); + auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]); + auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]); + auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]); + auto sumi12 = vpaddq_s32(sumi1, sumi2); + auto sumi34 = vpaddq_s32(sumi3, sumi4); + auto sumi = vpaddq_s32(sumi12, sumi34); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qs += 16; + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), 
vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; + auto shuff = vld1q_u8(k_shuff); + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[2*nrc_y] = {}; + int8x16_t qx[8]; + uint16x8x4_t scales16; + SignHelper sh; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); + auto qs = iq2[ibl].qs; + for (int is = 0; is < 2; ++is) { + auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is); + auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf)); + auto scales2 = vshrq_n_u8(scale_bits, 4); + scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1)); + scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1)); + auto s1 = vzip1q_u8(scales1, scales2); + auto s2 = vzip2q_u8(scales1, scales2); + scales16.val[0] = vmovl_u8(vget_low_u8 (s1)); + scales16.val[1] = vmovl_u8(vget_high_u8(s1)); + scales16.val[2] = vmovl_u8(vget_low_u8 (s2)); + scales16.val[3] = vmovl_u8(vget_high_u8(s2)); + for (int ib = 0; ib < QK_K/64; ++ib) { + auto v = vld1q_u8_x2((const uint8_t *)qs); + auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254)); + signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1)); + sh.init(); + for (int i = 0; i < 8; ++i) { + qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]}); + sh.apply_signs_1((uint8x16_t *)qx+i, signs128); + } + auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib])); + auto s32_2 = 
vmovl_u16(vget_high_u16(scales16.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib); + auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1])); + auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1])); + auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1])); + auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1])); + auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1 + auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3 + isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12); + isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34); + } + qs += 16; + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]); + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi)); + isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[2*nrc_y] = {}; + int8x16_t qx[8]; + uint16x8x4_t scales16; + SignHelper sh; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); + auto qs = iq2[ibl].qs; + auto qh = iq2[ibl].qh; + for (int is = 0; is < 2; ++is) { + auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is); + 
auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf)); + auto scales2 = vshrq_n_u8(scale_bits, 4); + scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1)); + scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1)); + auto s1 = vzip1q_u8(scales1, scales2); + auto s2 = vzip2q_u8(scales1, scales2); + scales16.val[0] = vmovl_u8(vget_low_u8 (s1)); + scales16.val[1] = vmovl_u8(vget_high_u8(s1)); + scales16.val[2] = vmovl_u8(vget_low_u8 (s2)); + scales16.val[3] = vmovl_u8(vget_high_u8(s2)); + for (int ib = 0; ib < QK_K/64; ++ib) { + auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib); + sh.init(); + for (int i = 0; i < 4; ++i) { + qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]}); + sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128); + qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]}); + sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128); + } + qs += 16; qh += 4; + auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib])); + auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib); + auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1])); + auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1])); + auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1])); + auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1])); + auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1 + auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3 + isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12); + isum[2*iy+1] = 
vmlaq_s32(isum[2*iy+1], s32_2, sumi34); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]); + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi)); + isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + int8x16_t qx[8]; + SignHelper sh; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d))); + auto qs = iq3[ibl].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + auto sas = vld1q_u8(iq3[ibl].sas + 16*ib); + auto scale_bits = vandq_u8(sas, vdupq_n_u8(1)); + auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402))); + auto signs128 = vandq_u8(sas, vdupq_n_u8(254)); + signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1)); + sh.init(); + for (int i = 0; i < 8; ++i) { + qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]}); + sh.apply_signs_1((uint8x16_t *)qx+i, signs128); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib); + auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]); + auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]); + auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]); + auto sumi4 
= ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]); + auto sumi12 = vpaddq_s32(sumi1, sumi2); + auto sumi34 = vpaddq_s32(sumi3, sumi4); + auto sumi = vpaddq_s32(sumi12, sumi34); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qs += 32; + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nbl = n / QK_K; + float32x4_t acc[nrc_y] = {}; + int32x4_t isum[nrc_y] = {}; + int8x16_t qx[8]; + auto m1 = vdupq_n_u8(1); + auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03}); + uint32_t stored_scales[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)); + auto qs = iq3[ibl].qs; + auto qh = iq3[ibl].qh; + auto scale_bits = vld1q_u8(iq3[ibl].scales); + uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) }; + scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1); + scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1); + vst1q_u8_x2((uint8_t *)stored_scales, scales8); + for (int ib = 0; ib < QK_K/32; ++ib) { + auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib); + if constexpr (nrc_y == 1) { + auto qh32 = (const uint32_t *)qh; + auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4})); + union { uint16x8_t vec; uint16_t val[8]; } hidx; + for (int i = 0; i < 4; ++i) { + auto idx_l = 
vmovl_u8(vld1_u8(qs)); + hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1); + qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]}); + auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1)); + qx[2*i+0] = vmulq_s8(qx[2*i+0], signs); + qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]}); + signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1)); + qx[2*i+1] = vmulq_s8(qx[2*i+1], signs); + signs128 = vshrq_n_u8(signs128, 1); + qs += 8; + } + } else { + for (int i = 0; i < 4; ++i) { + qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)], + iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]}); + auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1)); + qx[2*i+0] = vmulq_s8(qx[2*i+0], signs); + + qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)], + iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]}); + signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1)); + qx[2*i+1] = vmulq_s8(qx[2*i+1], signs); + + qs += 8; + signs128 = vshrq_n_u8(signs128, 1); + } + } + auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qh += 4; + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, 
ibl))), vcvtq_f32_s32(isum[iy])); + isum[iy] = vdupq_n_s32(0); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +inline void iq3_4_add_shift(int ibl, const Q8& q8, const int8x16x4_t& i8scales, uint8x16_t extra, + int32x4_t * isum) { + auto ms = vdupq_n_s8(k_shift); + int8x16_t s8_1, s8_2; + if constexpr (k_shift == 5) { + auto m1 = vdupq_n_u8(1); + s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + } else { + if constexpr (k_shift == 4) { + s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2))); + s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra)); + } else { + s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1))); + s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1))); + } + } + auto s16_1 = vmovl_s8(vget_low_s8 (s8_1)); + auto s16_2 = vmovl_s8(vget_high_s8(s8_1)); + auto s16_3 = vmovl_s8(vget_low_s8 (s8_2)); + auto s16_4 = vmovl_s8(vget_high_s8(s8_2)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto b8 = vld1_s16(q8.y[iy][ibl].bsums); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vld1_s16(q8.y[iy][ibl].bsums+4); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } + if constexpr (k_shift == 5) { + auto m1 = vdupq_n_u8(1); + s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = 
vshrq_n_u8(extra, 2); + s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2); + } else { + if constexpr (k_shift == 4) { + s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2))); + s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4))); + } else { + s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3))); + s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5))); + } + } + s16_1 = vmovl_s8(vget_low_s8 (s8_1)); + s16_2 = vmovl_s8(vget_high_s8(s8_1)); + s16_3 = vmovl_s8(vget_low_s8 (s8_2)); + s16_4 = vmovl_s8(vget_high_s8(s8_2)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3); + b8 = vld1_s16(q8.y[iy][ibl].bsums+12); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1); + isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2); + isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3); + } +} + +template +void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m03 = vdupq_n_u8(0x03); + auto ms = vdupq_n_u8(4); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values8 = vld1_s8(iq2nl_values); + auto values = vcombine_s8(values8, values8); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 
4) { + const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); + auto extra8 = vld1_u8(iq2[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq2[ibl].scales); + i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8)); + i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8)); + i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8)); + i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8)); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib); + qx[0] = vandq_u8( bits.val[0], m03); + qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03); + qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03); + qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03); + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 2)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, 
vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vandq_u8( bits.val[1], m03); + qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03); + qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03); + qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03); + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = vdupq_n_u8(0xf); + auto ms = nrc_y == 1 ? 
vdupq_n_u8(4) : vdupq_n_u8(8); + auto m03 = vdupq_n_u8(0x03); + auto m04 = vdupq_n_u8(0x04); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) }; + auto values = vld1q_s8(iq3nl_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)); + auto extra8 = vld1_u8(iq3[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq3[ibl].scales_l); + auto sh8 = vld1_u8(iq3[ibl].scales_h); + auto sh = vcombine_u8(sh8, sh8); + i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1)); + i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1)); + i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1)); + i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1)); + i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1))); + i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1))); + sh = vshrq_n_u8(sh, 4); + i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1))); + i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1))); + int32x4_t isum[nrc_y] = {}; + if 
constexpr (nrc_y == 1) { + iq3_4_add_shift(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib); + auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8( lbits.val[0], m03), vandq_u8(m04, vshlq_n_u8(hbits, 2))); + qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1))); + qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits)); + qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1))); + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 3)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vorrq_u8(vandq_u8( lbits.val[1], m03), vandq_u8(m04, vshrq_n_u8(hbits, 2))); + qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, 
vshrq_n_u8(hbits, 3))); + qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4))); + qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5))); + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7 + qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11 + qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + auto ms = vdupq_n_u8(4); + auto m32 = vdupq_n_s8(-32); + uint8x16x2_t shift_shuffle = { + vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values = vld1q_s8(iq4k_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_k_r4 * iq4 
= (const block_iq4_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); + auto extra8 = vld1_u8(iq4[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq4[ibl].scales_l); + auto sh = vld1q_u8(iq4[ibl].scales_h); + i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 + qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 + qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 2)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4))); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], 
m4))); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4))); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4))); // 12..15 + } + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + if constexpr (nrc_y == 1) { + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 + qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 + qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 + } else { + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]); + qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4))); // 16..19 + qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4))); // 20..23 + qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4))); // 24..27 + qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4))); // 28..31 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + auto ms = vdupq_n_u8(2); + auto m32 = vdupq_n_s8(-32); + auto m10 = vdupq_n_u8(0x10); + uint8x16x2_t shift_shuffle = { + 
vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}), + vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606}) + }; + auto values = vld1q_s8_x2(iq5nl_values); + int nbl = n / QK_K; + int8x16_t qx[4]; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d)); + auto extra8 = vld1_u8(iq5[ibl].extra); + uint8x16_t extra; + if constexpr (nrc_y == 1) { + extra = vcombine_u8(extra8, vshr_n_u8(extra8,1)); + } else { + extra = vcombine_u8(extra8, extra8); + } + auto sl = vld1q_u8_x2(iq5[ibl].scales_l); + auto sh = vld1q_u8(iq5[ibl].scales_h); + i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32); + i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32); + i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32); + i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32); + int32x4_t isum[nrc_y] = {}; + if constexpr (nrc_y == 1) { + iq3_4_add_shift(ibl, q8, i8scales, extra, isum); + } + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib); + auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib); + qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2 + qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits)); // 
aligns with 1st half of qx[1] in AVX2 + qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2 + qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2 + uint8x16_t shifts; + if constexpr (nrc_y == 1) { + qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows + qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7 + qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11 + qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15 + } else { + shifts = vandq_u8(ms, vshlq_n_u8(extra, 1)); + auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]); + extra = vshrq_n_u8(extra, 1); + qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows + qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7 + qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11 + qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15 + } + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2 + qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2 + qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2 + qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2 + if constexpr (nrc_y == 1) { + qx[0] = vqtbl2q_s8(values, qx[0]); // 16..19 from the 4 rows (second half, paired with the activations at offset +16) + qx[1] = vqtbl2q_s8(values, qx[1]); // 20..23 + qx[2] = vqtbl2q_s8(values, qx[2]); // 24..27 + qx[3] = vqtbl2q_s8(values, qx[3]); // 28..31 + } else { + auto shift = vqtbl1q_u8(shifts, 
shift_shuffle.val[1]); + qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 16..19 from the 4 rows (second half, paired with the activations at offset +16) + qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 20..23 + qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 24..27 + qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 28..31 + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +/* Unpacks four 16-byte vectors of packed 4-bit quants into eight vectors: qx[0..3] receive the low nibbles of bits.val[0..3] (masked with m4), qx[4..7] the high nibbles (logical shift right by 4). The trailing comments give the quant positions each vector covers. */ IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) { + qx[0] = vandq_u8(bits.val[0], m4); // 0...3 from the 4 rows + qx[1] = vandq_u8(bits.val[1], m4); // 16..19 + qx[2] = vandq_u8(bits.val[2], m4); // 4...7 + qx[3] = vandq_u8(bits.val[3], m4); // 20..23 + qx[4] = vshrq_n_u8(bits.val[0], 4); // 8..11 + qx[5] = vshrq_n_u8(bits.val[1], 4); // 24..27 + qx[6] = vshrq_n_u8(bits.val[2], 4); // 12..15 + qx[7] = vshrq_n_u8(bits.val[3], 4); // 28..31 +} + +/* Kernel over block_q2_k_r4 data, handled four x-rows per outer iteration (nrc_x must be a multiple of 4), against nrc_y activation rows held in q8. NOTE(review): the template parameter list and the Q8 template arguments are not visible in this text - confirm against the real source. */ template +void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = vdupq_n_u8(0x0f); + auto m03 = vdupq_n_u8(0x03); + int nbl = n / QK_K; + int8x16_t qx[4]; + float32x4_t acc[nrc_y] = {}; + int16x8x4_t i16scales; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + int32x4_t isum[nrc_y] = {}; + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); + auto m4 = vmulq_f32(vdupq_n_f32(-1.f), vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d+4))); + for (int is 
= 0; is < 2; ++is) { + auto sl = vld1q_u8_x2(iq2[ibl].scales + 32*is); + auto m = vshrq_n_u8(sl.val[0], 4); + i16scales.val[0] = vmovl_u8(vget_low_u8 (m)); + i16scales.val[1] = vmovl_u8(vget_high_u8(m)); + m = vshrq_n_u8(sl.val[1], 4); + i16scales.val[2] = vmovl_u8(vget_low_u8 (m)); + i16scales.val[3] = vmovl_u8(vget_high_u8(m)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = vdupq_n_s32(0); + auto bsums = vld1q_s16(q8.y[iy][ibl].bsums + 8*is); + auto b8 = vget_low_s16(bsums); + //auto bsums = q8.load_bsums(iy, ibl); + //auto b8 = vget_low_s16(bsums.val[0]); + sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[0]), b8, 0); + sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[0]), b8, 1); + sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[1]), b8, 2); + sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[1]), b8, 3); + b8 = vget_high_s16(bsums); + sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[2]), b8, 0); + sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[2]), b8, 1); + sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[3]), b8, 2); + sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[3]), b8, 3); + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(m4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi)); + } + m = vandq_u8(sl.val[0], mf); + i16scales.val[0] = vmovl_u8(vget_low_u8 (m)); + i16scales.val[1] = vmovl_u8(vget_high_u8(m)); + m = vandq_u8(sl.val[1], mf); + i16scales.val[2] = vmovl_u8(vget_low_u8 (m)); + i16scales.val[3] = vmovl_u8(vget_high_u8(m)); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib); + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[0], m03)); + qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 2), m03)); + qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 4), m03)); + qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 6), m03)); + for (int iy = 0; iy < nrc_y; ++iy) { 
+ auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[1], m03)); + qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 2), m03)); + qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 4), m03)); + qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 6), m03)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = vdupq_n_u8(0x0f); + auto m30 = vdupq_n_u8(0x30); + auto m32 = vdupq_n_s8(-32); + auto m03 = vdupq_n_u8(0x03); + auto m04 = vdupq_n_u8(0x04); + int nbl = n / QK_K; + int8x16_t qx[4]; + float32x4_t acc[nrc_y] = {}; + int8x16x4_t i8scales; + int16x8x4_t i16scales; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + int32x4_t isum[nrc_y] = {}; + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)); + auto sl = vld1q_u8_x2(iq3[ibl].scales_l); + auto sh = vld1q_u8(iq3[ibl].scales_h); + i8scales.val[0] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30))); + i8scales.val[1] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(vshlq_n_u8(sh, 2), m30))); + i8scales.val[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, 
m30))); + i8scales.val[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30))); + for (int is = 0; is < 2; ++is) { + i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0])); + i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0])); + i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1])); + i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1])); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib); + auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib); + hbits = veorq_u8(hbits, vdupq_n_u8(0xff)); + auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib])); + qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[0], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 2)))); + qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 1)))); + qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, hbits))); + qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 1)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + scales = vmovl_s16(vget_high_s16(i16scales.val[ib])); + qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[1], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 2)))); + qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 3)))); + qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 4)))); + qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, 
vshrq_n_u8(hbits, 5)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = vdupq_n_u8(0xf); + auto m3 = vdupq_n_u8(0x30); + int nbl = n / QK_K; + int8x16_t qx[8]; + int8x16x2_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); + auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d+4)); + m4 = vmulq_f32(m4, vdupq_n_f32(-1.f)); + auto sl = vld1q_u8_x2(iq4[ibl].scales_l); + auto sh = vld1q_u8(iq4[ibl].scales_h); + iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m3)); + iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)); + for (int is = 0; is < 2; ++is) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + float32x4x4_t fscales; + fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1)))); + fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1)))); + fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2)))); + fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto m8 = 
vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3); + } + } + iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m3)); + iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m3)); + int32x4_t isum[nrc_y] = {}; + for (int is = 0; is < 2; ++is) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1)); + scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1)); + scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2)); + scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2)); + for (int ib = 0; ib < 4; ++ib) { + auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib); + prepare_q4_k_quants(mf, bits, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = vdupq_n_u8(0xf); + auto m30 = vdupq_n_u8(0x30); + auto m10 = vdupq_n_u8(0x10); + int nbl = n / QK_K; + int8x16_t qx[8]; + int8x16x2_t iscales; + int32x4x4_t scales; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + 
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d)); + auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d+4)); + m4 = vmulq_f32(m4, vdupq_n_f32(-1.f)); + auto sl = vld1q_u8_x2(iq5[ibl].scales_l); + auto sh = vld1q_u8(iq5[ibl].scales_h); + iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m30)); + iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30)); + for (int is = 0; is < 2; ++is) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + float32x4x4_t fscales; + fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1)))); + fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1)))); + fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2)))); + fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2); + acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3); + } + } + iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30)); + iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m30)); + int32x4_t isum[nrc_y] = {}; + for (int is = 0; is < 2; ++is) { + auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is])); + auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is])); + scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1)); + scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1)); + scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2)); + scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2)); + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib); + 
auto hbits2 = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib); + auto hbits1 = vshlq_n_u8(hbits2, 4); + prepare_q4_k_quants(mf, lbits, qx); + qx[0] = vorrq_u8(qx[0], vandq_u8(m10, hbits1)); + qx[1] = vorrq_u8(qx[1], vandq_u8(m10, hbits2)); + qx[2] = vorrq_u8(qx[2], vandq_u8(m10, vshrq_n_u8(hbits1, 2))); + qx[3] = vorrq_u8(qx[3], vandq_u8(m10, vshrq_n_u8(hbits2, 2))); + qx[4] = vorrq_u8(qx[4], vandq_u8(m10, vshrq_n_u8(hbits1, 1))); + qx[5] = vorrq_u8(qx[5], vandq_u8(m10, vshrq_n_u8(hbits2, 1))); + qx[6] = vorrq_u8(qx[6], vandq_u8(m10, vshrq_n_u8(hbits1, 3))); + qx[7] = vorrq_u8(qx[7], vandq_u8(m10, vshrq_n_u8(hbits2, 3))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto mf = vdupq_n_u8(0x0f); + auto m3 = vdupq_n_u8(0x30); + auto m32 = vdupq_n_s8(-32); + int nbl = n / QK_K; + int8x16_t qx[4]; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ibl].d)); + int32x4_t isum[nrc_y] = {}; + for (int is = 0; is < 2; ++is) { + for (int ib = 0; ib < 4; ++ib) { + auto lbits = vld1q_u8_x4(iq6[ibl].ql + 256*is + 64*ib); + auto hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib); + auto iscales = vmovl_s8(vld1_s8(iq6[ibl].scales + 32*is + 8*ib)); + auto scales = vmovl_s16(vget_low_s16(iscales)); + qx[0] = vaddq_s8(m32, 
vorrq_u8(vandq_u8 (lbits.val[0], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4)))); + qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[2], mf), vandq_u8(m3, hbits))); + qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2)))); + qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + scales = vmovl_s16(vget_high_s16(iscales)); + hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib + 16); + qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[1], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4)))); + qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[3], mf), vandq_u8(m3, hbits))); + qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2)))); + qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2)))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16); + auto sumi = interleaved_dotq(qx, y); + isum[iy] = vmlaq_s32(isum[iy], scales, sumi); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy])); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + int nbl = n / QK_K; + float32x4_t acc[2*nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx); + for (int ibl = 0; ibl < nbl; ++ibl) { + auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0)); + auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4)); + 
int32x4_t isum[2*nrc_y] = {}; + for (int ib = 0; ib < QK_K/16; ++ib) { + auto q1 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 0); + auto q2 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 64); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2); + isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3); + isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3); + } + } + /* Fold the integer sums into the float accumulators. acc[2*iy+0] is stored at rows ix+0.. and uses the scales d4l (loaded from d+0); acc[2*iy+1] is stored at rows ix+4.. and uses d4h (loaded from d+4). */ for (int iy = 0; iy < nrc_y; ++iy) { + auto d8 = vdupq_n_f32(q8.scale(iy, ibl)); + const float * bsum = (const float *)q8.y[iy][ibl].bsums; + auto m8 = vdupq_n_f32(-128.f*bsum[0]); /* presumably compensates a +128 offset in the stored quants - TODO confirm against the repacking code */ + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0])); + acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1])); + acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8); + /* the m8 term of the second row group must use that group's own scales d4h, matching its main term two statements above */ acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix+0, iy, acc[2*iy+0]); + info.store(ix+4, iy, acc[2*iy+1]); + acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f); + } + } +} + +void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8<1, block_q8_0_x4> q8(info); + auto m4 = vdupq_n_u8(0xf); + auto values = vld1q_s8(iq4k_values); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + for (int ix = 0; ix < nrc_x; ix += 4) { + auto acc = vdupq_n_f32(0.f); + const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + auto y1 = vld1q_s8_x4(q8.y[0][ib4].qs); + auto y2 = 
vld1q_s8_x4(q8.y[0][ib4].qs+64); + for (int k = 0; k < 4; ++k) { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[0][ib4].d[k]))); + auto sumi = vdupq_n_s32(0); + const auto yval = k < 2 ? y1.val + 2*k : y2.val + 2*(k-2); + auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows + qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19 + sumi = vdotq_laneq_s32(sumi, qx[0], yval[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[1], yval[1], 0); + qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7 + qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23 + sumi = vdotq_laneq_s32(sumi, qx[2], yval[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[3], yval[1], 1); + qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11 + qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27 + sumi = vdotq_laneq_s32(sumi, qx[4], yval[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[5], yval[1], 2); + qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15 + qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31 + sumi = vdotq_laneq_s32(sumi, qx[6], yval[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[7], yval[1], 3); + acc = vfmaq_f32(acc, d4d8, vcvtq_f32_s32(sumi)); + } + } + info.store(ix, 0, acc); + } +} + +template +void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + Dequantizer deq(vx, bx); + int nb = n / QK4_NL; + GGML_ASSERT(nb%4 == 0); + int8x16_t qx[8]; + float d8[4*nrc_y]; + float32x4_t acc[nrc_y] = {}; + for (int ix = 0; ix < nrc_x; ix += 4) { + deq.new_row(ix); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); + } + for (int k = 0; k < 4; ++k) { + auto scales = 
deq.prepare(ib4, k, qx); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); + auto sumi = interleaved_dotq(qx, y); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); + acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, deq.result(acc[iy])); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +struct IQ4_NL_R4_Dequantizer { + IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {} + inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); } + inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); + auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + prepare_iq4_nl_quants(values, m4, bits, qx); + return scales; + } + inline float32x4_t result(float32x4_t acc) const { + return acc; + } + + const char * cx; + const size_t bx; + const block_iq4_nl_r4 * iq4; + const uint8x16_t m4 = vdupq_n_u8(0x0f); + const int8x16_t values; +}; + +struct Q4_0_R4_Dequantizer { + Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} + inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); } + inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); + auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); + qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows + qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19 + qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 + qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 + qx[4] = vandq_u8(bits.val[0], m4); // 8..11 + qx[5] = vandq_u8(bits.val[1], m4); // 24..27 + qx[6] = vandq_u8(bits.val[2], m4); // 12..15 + qx[7] = vandq_u8(bits.val[3], m4); // 28..31 + return scales; + } + inline float32x4_t 
result(float32x4_t acc) const { + return vmulq_f32(norm, acc); + } + + const char * cx; + const size_t bx; + const block_iq4_nl_r4 * iq4; + const uint8x16_t m4 = vdupq_n_u8(0xf0); + const uint8x16_t m88 = vdupq_n_u8(0x88); + const float32x4_t norm = vdupq_n_f32(1.f/16); +}; + +struct Q5_0_R4_Dequantizer { + Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} + inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); } + inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); + auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); + auto hbits = vld1q_u8(iq5[4*ib4+k].qh); + qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 + qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 + qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 + qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23 + qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11 + qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27 + qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15 + qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31 + return scales; + } + inline float32x4_t result(float32x4_t acc) const { + return acc; + } + + const char * cx; + const size_t bx; + const block_q5_0_r4 * iq5; + const uint8x16_t m4 = vdupq_n_u8(0x0f); + const uint8x16_t m5 = vdupq_n_u8(0x10); + const int8x16_t m16 = vdupq_n_s8(-16); +}; + +struct Q6_0_R4_Dequantizer { + Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} + inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); } + inline float32x4_t prepare(int 
ib4, int k, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d)); + auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs); + auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh); + qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3 + qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19 + qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7 + qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23 + qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11 + qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27 + qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15 + qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31 + return scales; + } + inline float32x4_t result(float32x4_t acc) const { + return acc; + } + + const char * cx; + const size_t bx; + const block_q6_0_r4 * iq6; + const uint8x16_t m4 = vdupq_n_u8(0x0f); + const uint8x16_t m6 = vdupq_n_u8(0x30); + const int8x16_t m32 = vdupq_n_s8(-32); +}; + +template +void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + int nb = n / QK8_0; + GGML_ASSERT(nb%4 == 0); + float32x4_t acc[nrc_y] = {}; + float d8[4*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); + } + for (int k = 0; k < 4; ++k) { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[4*ib4+k].d)); + auto qx1 = 
vld1q_s8_x4(iq8[4*ib4+k].qs); + auto qx2 = vld1q_s8_x4(iq8[4*ib4+k].qs+64); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx1.val[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx1.val[1], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx1.val[2], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx1.val[3], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx2.val[0], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx2.val[1], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx2.val[2], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx2.val[3], y.val[1], 3); + auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); + acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = vdupq_n_f32(0.f); + } + } +} + +template +void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(n == 128); + int8x16x4_t qx[8]; + float32x4_t scales[4]; + float32x4_t scales_y[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int k = 0; k < 4; ++k) { + scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d)); + qx[2*k+0] = vld1q_s8_x4(iq8[k].qs); + qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto by = (const block_q8_0_x4 *)info.src1_row(iy); + auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d)); + scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0); + scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1); + scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2); + scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3); + auto sumf = vdupq_n_f32(0.f); + for (int k = 0; k < 4; ++k) { + auto y = vld1q_s8_x2(by->qs+32*k); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0); + sumi = 
vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3); + sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi)); + } + info.store(ix, iy, sumf); + } + } +} + +#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \ + m.funcs[0] = func;\ + m.funcs[1] = func;\ + m.funcs[2] = func;\ + m.funcs[3] = func;\ + m.funcs[4] = func;\ + m.funcs[5] = func;\ + m.funcs[6] = func;\ + m.funcs[7] = func;\ + +#define SET_MUL_MAT_FUNCTIONS(m, func) \ + m.funcs[0] = func<1>;\ + m.funcs[1] = func<2>;\ + m.funcs[2] = func<3>;\ + m.funcs[3] = func<4>;\ + m.funcs[4] = func<5>;\ + m.funcs[5] = func<6>;\ + m.funcs[6] = func<7>;\ + m.funcs[7] = func<8>;\ + template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_0_q8_0; - m.funcs[1] = mul_mat_qX_0_q8_0; - m.funcs[2] = mul_mat_qX_0_q8_0; - m.funcs[3] = mul_mat_qX_0_q8_0; - m.funcs[4] = mul_mat_qX_0_q8_0; - m.funcs[5] = mul_mat_qX_0_q8_0; - m.funcs[6] = mul_mat_qX_0_q8_0; - m.funcs[7] = mul_mat_qX_0_q8_0; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_0_q8_0, Dequantizer); } else if constexpr (std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_1_q8_1; - m.funcs[1] = mul_mat_qX_1_q8_1; - m.funcs[2] = mul_mat_qX_1_q8_1; - m.funcs[3] = mul_mat_qX_1_q8_1; - m.funcs[4] = mul_mat_qX_1_q8_1; - m.funcs[5] = mul_mat_qX_1_q8_1; - m.funcs[6] = mul_mat_qX_1_q8_1; - m.funcs[7] = mul_mat_qX_1_q8_1; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_1_q8_1, Dequantizer); } else { - m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; - m.funcs[1] = mul_mat_qX_K_q8_K_T<2, 
Dequantizer>; - m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; - m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; - m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; - m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; - m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; - m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer); } } @@ -6547,54 +12243,144 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); break; case GGML_TYPE_IQ1_BN: - m.funcs[0] = mul_mat_iq1bn_q8_K64<1>; - m.funcs[1] = mul_mat_iq1bn_q8_K64<2>; - m.funcs[2] = mul_mat_iq1bn_q8_K64<3>; - m.funcs[3] = mul_mat_iq1bn_q8_K64<4>; - m.funcs[4] = mul_mat_iq1bn_q8_K64<5>; - m.funcs[5] = mul_mat_iq1bn_q8_K64<6>; - m.funcs[6] = mul_mat_iq1bn_q8_K64<7>; - m.funcs[7] = mul_mat_iq1bn_q8_K64<8>; + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1bn_q8_K64); expected_Btype = GGML_TYPE_Q8_K64; break; case GGML_TYPE_IQ2_BN: - m.funcs[0] = mul_mat_iq2bn_q8_K64<1>; - m.funcs[1] = mul_mat_iq2bn_q8_K64<2>; - m.funcs[2] = mul_mat_iq2bn_q8_K64<3>; - m.funcs[3] = mul_mat_iq2bn_q8_K64<4>; - m.funcs[4] = mul_mat_iq2bn_q8_K64<5>; - m.funcs[5] = mul_mat_iq2bn_q8_K64<6>; - m.funcs[6] = mul_mat_iq2bn_q8_K64<7>; - m.funcs[7] = mul_mat_iq2bn_q8_K64<8>; + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2bn_q8_K64); expected_Btype = GGML_TYPE_Q8_K64; break; + case GGML_TYPE_IQ2_BN_R4: + m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>; + m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>; + m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>; + m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>; + m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>; + //m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>; + //m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>; + //m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>; + expected_Btype = GGML_TYPE_Q8_K16; + break; case GGML_TYPE_Q4_0: MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q4_1: 
MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_1; + expected_Btype = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q5_0: MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q5_1: MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_1; + expected_Btype = GGML_TYPE_Q8_1_X4; break; case GGML_TYPE_Q6_0: MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_Q8_0: MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; break; case GGML_TYPE_IQ4_NL: MulMat::set_functions(m); - expected_Btype = GGML_TYPE_Q8_0; + expected_Btype = GGML_TYPE_Q8_0_X4; + break; + case GGML_TYPE_IQ4_NL_R4: + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer); + expected_Btype = GGML_TYPE_Q8_0_X4; + break; + case GGML_TYPE_IQ4_XS_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_IQ4_KS_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_ks_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ2_XXS_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k); + m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ2_XS_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xs_r4_q8_k); + m.func16 = mul_mat_iq2_xs_r4_q8_k<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ2_S_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_s_r4_q8_k); + m.func16 = mul_mat_iq2_s_r4_q8_k<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ3_XXS_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k); + m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ3_S_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k); + m.func16 = mul_mat_iq3_s_r4_q8_k<16>; + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q2_K_R4: + 
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q3_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q3_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q4_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q4_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_Q5_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q5_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K32; + break; + case GGML_TYPE_Q6_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q8_K_R8: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k); + expected_Btype = GGML_TYPE_Q8_KR8; + break; + case GGML_TYPE_IQ2_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ3_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ4_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_IQ5_K_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k); + expected_Btype = GGML_TYPE_Q8_K; + break; + case GGML_TYPE_Q4_0_R4: + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer); + expected_Btype = GGML_TYPE_Q8_0_X4; + break; + case GGML_TYPE_Q5_0_R4: + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer); + expected_Btype = GGML_TYPE_Q8_0_X4; + break; + case GGML_TYPE_Q6_0_R4: + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer); + expected_Btype = GGML_TYPE_Q8_0_X4; + break; + case GGML_TYPE_Q8_0_R4: + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0); + expected_Btype = GGML_TYPE_Q8_0_X4; break; default: return false; @@ -6790,6 +12576,15 @@ struct F16 { static inline float reduce_add(Data data) { return _mm512_reduce_add_ps(data); } static inline Data max(Data v1, Data v2) { return _mm512_max_ps(v1, v2); } static inline Data add(Data v1, Data v2) { 
return _mm512_add_ps(v1, v2); } + static inline Data set4(const float * ptr) { + auto v128 = _mm_loadu_ps(ptr); + auto v256 = _mm256_set_m128(v128, v128); + return _mm512_insertf32x8(_mm512_castps256_ps512(v256), v256, 1); + } + static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x00), prev); } + static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0x55), prev); } + static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xaa), prev); } + static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm512_fmadd_ps(v1, _mm512_shuffle_ps(v2, v2, 0xff), prev); } #elif defined __AVX2__ using Data = __m256; constexpr static int block_size = 8; @@ -6807,10 +12602,18 @@ struct F16 { static inline float reduce_add(Data data) { return hsum_float_8(data); } static inline Data max(Data v1, Data v2) { return _mm256_max_ps(v1, v2); } static inline Data add(Data v1, Data v2) { return _mm256_add_ps(v1, v2); } + static inline Data set4(const float * ptr) { + auto v128 = _mm_loadu_ps(ptr); + return _mm256_set_m128(v128, v128); + } + static inline Data fmadd_lane0(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x00), prev); } + static inline Data fmadd_lane1(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0x55), prev); } + static inline Data fmadd_lane2(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xaa), prev); } + static inline Data fmadd_lane3(Data prev, Data v1, Data v2) { return _mm256_fmadd_ps(v1, _mm256_shuffle_ps(v2, v2, 0xff), prev); } #else using Data = float16x8_t; constexpr static int block_size = 8; - constexpr static int num_registers = 32; + //constexpr static int num_registers = 32; constexpr static int q_step = 8; static inline Data zero() { return vdupq_n_f16(0); } 
static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); } @@ -6838,6 +12641,14 @@ struct F16 { } static inline Data max(Data v1, Data v2) { return vmaxq_f16(v1, v2); } static inline Data add(Data v1, Data v2) { return vaddq_f16(v1, v2); } + static inline float16x4_t set4(const float * ptr) { + auto val32 = vld1q_f32(ptr); + return vcvt_f16_f32(val32); + } + static inline Data fmadd_lane0(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 0); } + static inline Data fmadd_lane1(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 1); } + static inline Data fmadd_lane2(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 2); } + static inline Data fmadd_lane3(Data prev, Data v1, float16x4_t v2) { return vfmaq_lane_f16(prev, v1, v2, 3); } #endif template static inline float reduce_max(const Data * data) { return reduce_T(data); @@ -6886,304 +12697,138 @@ struct HelperF16 final : public BaseHelper { } }; -void quantize_row_q8_0(const float * x, block_q8_0 * y, int k) { - const int nb = k / QK8_0; - const int nb4 = 4*(nb/4); - -#if defined(__aarch64__) - block_q8_0_x4 * y4 = (block_q8_0_x4 *)y; - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 
1.0f/d : 0.0f; - - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - if (i < nb4) { - y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } - } -#else - block_q8_0_x4 * y4 = (block_q8_0_x4 *)y; - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - const float d = maxScalar / 127.f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - if (i < nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } - } -#endif -} - -void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - const int nb4 = 4*(nb/4); - block_q8_1_x4 * y4 = (block_q8_1_x4 *)y; -#if defined(__aarch64__) - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - 
for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - if (i < nb4) { - y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); - y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); - y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); - y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); - } else { - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - - accv = vaddq_s32(accv, vi); - } - - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } else { - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } - } -#else - for (int i = 0; i < nb; i++) { - int i4 = i/4, ir = i%4; - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, 
_mm_movehdup_ps( max4 ) ); - const float max_scalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = max_scalar / 127.f; - if (i < nb4) { - y4[i4].d[ir] = GGML_FP32_TO_FP16(d); - } else { - y[i].d = GGML_FP32_TO_FP16(d); - } - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Compute the sum of the quants and set y[i].s - if (i < nb4) { - y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } else { - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - } - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - if (i < 
nb4) { - _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); - } else { - _mm256_storeu_si256((__m256i *)y[i].qs, i0); - } - } -#endif -} - template struct HelperQ80 final : public BaseHelper { using Base = BaseHelper; +#ifdef HAVE_FANCY_SIMD + using block_q8 = block_q8_1; +#else using block_q8 = block_q8_0; +#endif HelperQ80(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { int j = F16::block_size*i; - auto dl = (const block_q8_0_x4 *)Base::lblock(l1) + j/(4*QK8_0); - int ii = (j/QK8_0)%4; + auto dl = (const block_q8_0 *)Base::lblock(l1) + j/QK8_0; #ifdef __aarch64__ - const float16_t * d = (const float16_t *)dl->d; - auto vd = F16::set1(d[ii]); - auto qs = vld1_s8_x2(dl->qs + 32*ii + j%32); + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); + int ii = j%QK8_0; + auto qs = vld1_s8_x2(dl->qs + ii); v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[0]))); v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(qs.val[1]))); #else - auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d[ii])); + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); #ifdef HAVE_FANCY_SIMD - v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+0)))); - v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+2*ii+1)))); + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+0)))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)dl->qs+1)))); #else - v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32))))); - v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+32*ii+j%32+8))))); + int ii = j%QK8_0; + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i 
*)(dl->qs+ii+0))))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(dl->qs+ii+8))))); #endif #endif } static inline void convert(int nq, int stride_q, const float * q, block_q8_0 * y) { - GGML_ASSERT(nq <= step); + //GGML_ASSERT(nq <= step); Why did I have this assert? for (int i = 0; i < nq; ++i) { - quantize_row_q8_0(q, y, D); + quantize_row_q8_0_x4(q, y, D); q += stride_q; y += D/QK8_0; } } static inline void convert(int nq, int stride_q, const float * q, block_q8_1 * y) { - GGML_ASSERT(nq <= step); + //GGML_ASSERT(nq <= step); Why did I have this assert? for (int i = 0; i < nq; ++i) { - quantize_row_q8_1(q, y, D); + quantize_row_q8_1_x4(q, y, D); q += stride_q; y += D/QK8_1; } } }; +template +struct HelperQ80R4 : public BaseHelper { + using Base = BaseHelper; +#ifdef __AVX2__ + using block_q8 = block_q8_1; +#else + using block_q8 = block_q8_0; +#endif + HelperQ80R4(int nk, const HelperQ80& q8) : Base(q8.data, q8.stride) { + r4 = repack(nk, q8); + Base::data = (const char *)r4.data(); + Base::stride = (D/QK8_0)*sizeof(block_q8_0); + } + + static std::vector repack(int nk, const HelperQ80 q8) { + static_assert(D%QK8_0 == 0); + GGML_ASSERT(nk%4 == 0); + constexpr int nblock = D/QK8_0; + std::vector result(nblock * nk/4); + auto y = result.data(); + const block_q8_0 * x4[4]; + for (int row = 0; row < nk; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; +#ifdef __AVX2__ + auto m0 = _mm256_loadu_si256((const __m256i *)x4[0][ib].qs); + auto m1 = _mm256_loadu_si256((const __m256i *)x4[1][ib].qs); + auto m2 = _mm256_loadu_si256((const __m256i *)x4[2][ib].qs); + auto m3 = _mm256_loadu_si256((const __m256i *)x4[3][ib].qs); + auto t0 = _mm256_unpacklo_epi32(m0, m1); + auto t1 = _mm256_unpacklo_epi32(m2, m3); + auto t2 = _mm256_unpackhi_epi32(m0, m1); + auto t3 = 
_mm256_unpackhi_epi32(m2, m3); + m0 = _mm256_unpacklo_epi64(t0, t1); + m1 = _mm256_unpackhi_epi64(t0, t1); + m2 = _mm256_unpacklo_epi64(t2, t3); + m3 = _mm256_unpackhi_epi64(t2, t3); + _mm256_storeu_si256((__m256i *)y[ib].qs + 0, m0); + _mm256_storeu_si256((__m256i *)y[ib].qs + 1, m1); + _mm256_storeu_si256((__m256i *)y[ib].qs + 2, m2); + _mm256_storeu_si256((__m256i *)y[ib].qs + 3, m3); +#elif defined __ARM_NEON + auto m0 = vld1q_s8_x2(x4[0][ib].qs); + auto m1 = vld1q_s8_x2(x4[1][ib].qs); + auto m2 = vld1q_s8_x2(x4[2][ib].qs); + auto m3 = vld1q_s8_x2(x4[3][ib].qs); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0, m0); + vst1q_s8_x2(y[ib].qs + 32, m1); + vst1q_s8_x2(y[ib].qs + 
64, m2); + vst1q_s8_x2(y[ib].qs + 96, m3); +#else + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + } + } +#endif + } + y += nblock; + } + return result; + } + + std::vector r4; +}; + template struct HelperQ40 final : public BaseHelper { using Base = BaseHelper; @@ -7406,7 +13051,7 @@ struct FlashMS { if (smax > M[j]) { if (M[j] > -INFINITY) { float m = expf(M[j] - smax); - vms[j] = F16::set1(m); + vms[j] = m; need_scaling[j] = 1; S[j] *= m; } else { @@ -7588,7 +13233,7 @@ struct FlashMS { cache_t cache[q_step*k_step]; float S[q_step], M[q_step]; int need_scaling[q_step]; - F16::Data vms[q_step]; + float vms[q_step]; const F16::Data vscale; const float softcap; const ggml_half h_inf; @@ -7608,75 +13253,90 @@ struct FlashQKV { // Hence, for now, we will not handle head sizes of 80 and 112 template inline void accumulate_qkv(const VHelper& vh, const FlashMS& fms) { - F16::Data vk[2*q_step]; + F16::Data v[8]; + for (int j = 0; j < q_step; ++j) { + auto R = qkv_cache + D*j; + if (fms.need_scaling[j] == 2) { + std::memset(R, 0, D*sizeof(qkv_cache_t)); + } + else if (fms.need_scaling[j] == 1) { + auto vms = F16::set1(fms.vms[j]); + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i))); + } + } + } for (int i = 0; i < D/F16::block_size; i += 2) { - for (int j = 0; j < q_step; ++j) { - if (fms.need_scaling[j] == 2) { - vk[2*j+0] = vk[2*j+1] = F16::zero(); - } else { - auto R = qkv_cache + D*j; - vk[2*j+0] = F16::load(R + F16::block_size*i); - vk[2*j+1] = F16::load(R + F16::block_size*(i + 1)); - if (fms.need_scaling[j] == 1) { - vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]); - vk[2*j+1] = F16::mul(vk[2*j+1], fms.vms[j]); - } - } - } - F16::Data v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); + for (int l = 0; l < k_step; l += 4) 
{ + vh.load(l+0, i, v[0], v[4]); + vh.load(l+1, i, v[1], v[5]); + vh.load(l+2, i, v[2], v[6]); + vh.load(l+3, i, v[3], v[7]); for (int j = 0; j < q_step; ++j) { - auto vs = F16::set1(fms.cache[k_step*j + l1]); - vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs); - vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs); + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto vs = F16::set4(fms.cache + k_step*j + l); + s1 = F16::fmadd_lane0(s1, v[0], vs); + s2 = F16::fmadd_lane0(s2, v[4], vs); + s1 = F16::fmadd_lane1(s1, v[1], vs); + s2 = F16::fmadd_lane1(s2, v[5], vs); + s1 = F16::fmadd_lane2(s1, v[2], vs); + s2 = F16::fmadd_lane2(s2, v[6], vs); + s1 = F16::fmadd_lane3(s1, v[3], vs); + s2 = F16::fmadd_lane3(s2, v[7], vs); + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); } } - for (int j = 0; j < q_step; ++j) { - auto R = qkv_cache + D*j; - F16::store(R + F16::block_size*(i + 0), vk[2*j+0]); - F16::store(R + F16::block_size*(i + 1), vk[2*j+1]); - } } } - template = 2>> + template inline void accumulate_qkv(int nq1, const VHelper& vh, const FlashMS& fms) { - F16::Data vk[2*q_step]; + F16::Data v[8]; + for (int j = 0; j < nq1; ++j) { + auto R = qkv_cache + D*j; + if (fms.need_scaling[j] == 2) { + std::memset(R, 0, D*sizeof(qkv_cache_t)); + } + else if (fms.need_scaling[j] == 1) { + auto vms = F16::set1(fms.vms[j]); + for (int i = 0; i < D/F16::block_size; ++i) { + F16::store(R + F16::block_size*i, F16::mul(vms, F16::load(R + F16::block_size*i))); + } + } + } for (int i = 0; i < D/F16::block_size; i += 2) { - for (int j = 0; j < nq1; ++j) { - if (fms.need_scaling[j] == 2) { - vk[2*j+0] = vk[2*j+1] = F16::zero(); - } else { - auto R = qkv_cache + D*j; - vk[2*j+0] = F16::load(R + F16::block_size*i); - vk[2*j+1] = F16::load(R + F16::block_size*(i + 1)); - if (fms.need_scaling[j] == 1) { - vk[2*j+0] = F16::mul(vk[2*j+0], fms.vms[j]); - vk[2*j+1] = F16::mul(vk[2*j+1], 
fms.vms[j]); - } - } - } - F16::Data v1, v2; - for (int l1 = 0; l1 < k_step; ++l1) { - vh.load(l1, i, v1, v2); + for (int l = 0; l < k_step; l += 4) { + vh.load(l+0, i, v[0], v[4]); + vh.load(l+1, i, v[1], v[5]); + vh.load(l+2, i, v[2], v[6]); + vh.load(l+3, i, v[3], v[7]); for (int j = 0; j < nq1; ++j) { - auto vs = F16::set1(fms.cache[k_step*j + l1]); - vk[2*j+0] = F16::fmadd(vk[2*j+0], v1, vs); - vk[2*j+1] = F16::fmadd(vk[2*j+1], v2, vs); + auto R = qkv_cache + D*j; + auto s1 = F16::load(R + F16::block_size*(i+0)); + auto s2 = F16::load(R + F16::block_size*(i+1)); + auto vs = F16::set4(fms.cache + k_step*j + l); + s1 = F16::fmadd_lane0(s1, v[0], vs); + s2 = F16::fmadd_lane0(s2, v[4], vs); + s1 = F16::fmadd_lane1(s1, v[1], vs); + s2 = F16::fmadd_lane1(s2, v[5], vs); + s1 = F16::fmadd_lane2(s1, v[2], vs); + s2 = F16::fmadd_lane2(s2, v[6], vs); + s1 = F16::fmadd_lane3(s1, v[3], vs); + s2 = F16::fmadd_lane3(s2, v[7], vs); + F16::store(R + F16::block_size*(i+0), s1); + F16::store(R + F16::block_size*(i+1), s2); } } - for (int j = 0; j < nq1; ++j) { - auto R = qkv_cache + D*j; - F16::store(R + F16::block_size*(i + 0), vk[2*j+0]); - F16::store(R + F16::block_size*(i + 1), vk[2*j+1]); - } } } inline void normalize_and_store(const FlashMS& fms, int j, const qkv_cache_t * R, float * qkv) const { GGML_ASSERT(fms.S[j] > 0); auto norm = F16::set1(1/fms.S[j]); + //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f); for (int i = 0; i < D/F16::block_size; ++i) { auto r = F16::load(R + F16::block_size*i); F16::store(qkv + F16::block_size*i, F16::mul(norm, r)); @@ -7701,156 +13361,281 @@ struct FlashQKV { } } - qkv_cache_t qkv_cache[D*q_step]; + // qkv_cache_t qkv_cache[D*q_step]; + // The initializer is not actually required. But the compiler cannot figure out that when qkv_cache is + // first used for q_step rows, fms.need_scaling[j] is always 2, which zeroes the content of qkv_cache. 
+ // As a result, we get an infinite stream of warnings about uninitialized variable use (one for each + // combination of D, q_step, k_step), which is extremely annoying. Hence, I succumb to the trend of + // constantly being saved by others (the compiler in this case), and add this 100% unnecessary initialization. + qkv_cache_t qkv_cache[D*q_step] = {}; }; +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q8_0_r4_q8_1_128([[maybe_unused]] int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n == 128); + //Q8 q8(info); + __m512i qx[16]; + __m512 scales[4]; + __m512 scales_m[4]; + __m512 dy[4]; + auto m127 = _mm512_set1_epi8(127); + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_x4 * q8l = (const block_q8_0_x4 *)((const char *)vx + (ix+0)*bx); + const block_q8_0_x4 * q8h = (const block_q8_0_x4 *)((const char *)vx + (ix+4)*bx); + for (int k = 0; k < 4; ++k) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8l[k].d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)q8h[k].d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + scales[k] = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + scales_m[k] = _mm512_mul_ps(scales[k], _mm512_set1_ps(-63.5f)); + qx[4*k+0] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+0)), + _mm256_loadu_si256((const __m256i *)q8h[k].qs+0), 1); + qx[4*k+1] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+1)), + _mm256_loadu_si256((const __m256i *)q8h[k].qs+1), 1); + qx[4*k+2] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+2)), + _mm256_loadu_si256((const __m256i *)q8h[k].qs+2), 1); + qx[4*k+3] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[k].qs+3)), + _mm256_loadu_si256((const __m256i 
*)q8h[k].qs+3), 1); + qx[4*k+0] = _mm512_add_epi8(qx[4*k+0], m127); + qx[4*k+1] = _mm512_add_epi8(qx[4*k+1], m127); + qx[4*k+2] = _mm512_add_epi8(qx[4*k+2], m127); + qx[4*k+3] = _mm512_add_epi8(qx[4*k+3], m127); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto by = (const block_q8_1_x4 *)info.src1_row(iy); + //auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8.y[iy][0].d)); + auto dall = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)by->d)); + auto d128 = _mm256_castps256_ps128(dall); + auto m128 = _mm256_extractf128_ps(dall, 1); + auto m256 = _mm256_set_m128(m128, m128); + auto m512 = _mm512_insertf32x8(_mm512_castps256_ps512(m256), m256, 1); + auto sumf = _mm512_mul_ps(scales_m[0], _mm512_shuffle_ps(m512, m512, 0x00)); + sumf = _mm512_fmadd_ps(scales_m[1], _mm512_shuffle_ps(m512, m512, 0x55), sumf); + sumf = _mm512_fmadd_ps(scales_m[2], _mm512_shuffle_ps(m512, m512, 0xaa), sumf); + sumf = _mm512_fmadd_ps(scales_m[3], _mm512_shuffle_ps(m512, m512, 0xff), sumf); + auto d256 = _mm256_set_m128(d128, d128); + auto d512 = _mm512_insertf32x8(_mm512_castps256_ps512(d256), d256, 1); + dy[0] = _mm512_mul_ps(scales[0], _mm512_shuffle_ps(d512, d512, 0x00)); + dy[1] = _mm512_mul_ps(scales[1], _mm512_shuffle_ps(d512, d512, 0x55)); + dy[2] = _mm512_mul_ps(scales[2], _mm512_shuffle_ps(d512, d512, 0xaa)); + dy[3] = _mm512_mul_ps(scales[3], _mm512_shuffle_ps(d512, d512, 0xff)); + for (int k = 0; k < 4; ++k) { + //auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][0].qs+k); + auto y8 = _mm256_loadu_si256((const __m256i*)by->qs+k); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[4*k+3], 
_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + sumf = _mm512_fmadd_ps(dy[k], _mm512_cvtepi32_ps(sumi), sumf); + } + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 0), _mm512_extractf32x4_ps(sumf, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sumf, 2), _mm512_extractf32x4_ps(sumf, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } +} +#endif + template struct FlashQKfp32 { static_assert(D%F16::block_size == 0 && D <= 256); static_assert(k_step%F16::block_size == 0); static_assert(q_step <= 4 || q_step%4 == 0); -#ifdef __aarch64__ - constexpr static bool is_small_head = false; -#else - constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size; -#endif - - template , typename q_float> - static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask, - F16::Data * qv, F16::Data * vk, FlashMS& fms) { - // q index is q_step*i1 + m1 - // k index is k_step*k1 + l1 - const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1); - fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY; - if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) { - return; - } - auto qr = q + m1*stride_q; - for (int i = 0; i < D/F16::block_size; ++i) qv[i] = F16::load(qr + F16::block_size*i); - if (mp[l1+0] != fms.h_inf) { - auto vsum = F16::zero(); - for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i], qv[i]); - fms.cache[k_step*m1 + l1 + 0] = F16::reduce_add(vsum); - } - if (mp[l1+1] != fms.h_inf) { - auto vsum = F16::zero(); - for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i+D/16], qv[i]); - fms.cache[k_step*m1 + l1 + 1] = F16::reduce_add(vsum); - } - } - - template , typename q_float> - static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask, - F16::Data * vk, FlashMS& fms) { - // q index is q_step*i1 + m1 - // k index is k_step*k1 + l1 - const ggml_half * 
mp = (const ggml_half *)(mask + stride_m*m1); - if (mp[l1] == fms.h_inf) { - fms.cache[k_step*m1 + l1] = -INFINITY; - return; - } - auto qr = q + m1*stride_q; - auto vsum = F16::zero(); - for (int i = 0; i < D/F16::block_size; ++i) { - vsum = F16::fmadd(vsum, vk[i], F16::load(qr + F16::block_size*i)); - } - fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum); - } - - template , typename q_float> - static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, - FlashMS& fms) { - F16::Data qv[D/F16::block_size]; - F16::Data vk[D/(F16::block_size/2)]; - for (int l1 = 0; l1 < k_step; l1 += 2) { - kh.load_2(l1, vk); - for (int m1 = 0; m1 < q_step; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms); - } - } - } - - template , typename q_float> - static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m, - const q_float * q, const char * mask, FlashMS& fms) { - F16::Data vk[D/F16::block_size]; - for (int l1 = 0; l1 < k_step; ++l1) { - kh.load(l1, vk); - for (int m1 = 0; m1 < q_step; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms); - } - } - } - - template , typename q_float> - static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, - FlashMS& fms) { - F16::Data qv[D/F16::block_size]; - F16::Data vk[D/(F16::block_size/2)]; - for (int l1 = 0; l1 < k_step; l1 += 2) { - kh.load_2(l1, vk); - for (int m1 = 0; m1 < nq; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms); - } - } - } - - template , typename q_float> - static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m, - const q_float * q, const char * mask, FlashMS& fms) { - F16::Data vk[D/F16::block_size]; - for (int l1 = 0; l1 < k_step; ++l1) { - kh.load(l1, vk); - for (int m1 = 0; m1 < nq; ++m1) { - mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms); - } - } - } 
- +#ifdef __AVX2__ template static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { - if constexpr (is_small_head) { - mult_mask_kq(kh, stride_q, stride_m, q, mask, fms); - } - else { - mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms); - } -#ifdef __aarch64__ - float32x4_t vk[k_step/4]; - for (int j = 0; j < q_step; ++j) { - fms.update_M_S(j, vk); - } +#ifdef HAVE_FANCY_SIMD + constexpr int nrc_q = 8; + constexpr int nrc_k = 8; #else + // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4 + constexpr int nrc_q = 4; + constexpr int nrc_k = 8; +#endif + constexpr int qrem = q_step - nrc_q*(q_step/nrc_q); + constexpr int krem = k_step - nrc_k*(k_step/nrc_k); + DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; + for (int iq = 0; iq < q_step/nrc_q; ++iq) { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_Qx_Qy_MxN_fa, QFT>(D, kh.block, kh.stride, k_step - krem, info); + } + info.cur_y += nrc_q; + } + if constexpr (qrem > 0) { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_Qx_Qy_MxN_fa, QFT>(D, kh.block, kh.stride, k_step - krem, info); + } + } F16::Data vk[k_step/F16::block_size]; for (int j = 0; j < q_step; ++j) { - fms.update_M_S(j, vk); + fms.update_M_S(j, vk, mask + stride_m*j); } -#endif } +#else + template + static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, + FlashMS& fms) { + constexpr int nrc_q = 4; + constexpr int nrc_k = 6; + constexpr int qrem = q_step - nrc_q*(q_step/nrc_q); + constexpr int krem = k_step - nrc_k*(k_step/nrc_k); + DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, 
nullptr}; + for (int iq = 0; iq < q_step/nrc_q; ++iq) { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_f16_f16_NxN(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_f16_f16_NxN(D, kh.block, kh.stride, k_step - krem, info); + } + info.cur_y += nrc_q; + } + if constexpr (qrem > 0) { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_f16_f16_NxN(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_f16_f16_NxN(D, kh.block, kh.stride, k_step - krem, info); + } + } + float32x4_t vk[k_step/4]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } + } +#endif +#ifdef __AVX2__ template static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { - if constexpr (is_small_head) { - mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms); - } - else { - mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms); - } -#ifdef __aarch64__ - float32x4_t vk[k_step/4]; - for (int j = 0; j < nq; ++j) { - fms.update_M_S(j, vk); - } +#ifdef HAVE_FANCY_SIMD + constexpr int nrc_q = 8; + constexpr int nrc_k = 8; #else - F16::Data vk[k_step/F16::block_size]; - for (int j = 0; j < nq; ++j) { - fms.update_M_S(j, vk); - } + // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4 + constexpr int nrc_q = 4; + constexpr int nrc_k = 8; #endif + static_assert(k_step%nrc_k == 0); + int qrem = q_step - nrc_q*(q_step/nrc_q); + DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; + for (int iq = 0; iq < nq/nrc_q; ++iq) { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + info.cur_y += nrc_q; + } + if (qrem > 0) { + switch (qrem) { + case 1: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; + case 
2: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; + case 3: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; +#ifdef HAVE_FANCY_SIMD + case 4: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; + case 5: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; + case 6: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; + case 7: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_Qx_Qy_MxN_fa4, QFT>(D, kh.block, kh.stride, ik*nrc_k, info); + } + } break; +#endif + } + } + F16::Data vk[k_step/F16::block_size]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } } +#else + template + static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, + FlashMS& fms) { + constexpr int nrc_q = 4; + constexpr int nrc_k = 6; + constexpr int krem = k_step - nrc_k*(k_step/nrc_k); + const int qrem = q_step - nrc_q*(q_step/nrc_q); + DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr}; + for (int iq = 0; iq < nq/nrc_q; ++iq) { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_f16_f16_NxN(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_f16_f16_NxN(D, kh.block, kh.stride, k_step - krem, info); + } + info.cur_y += nrc_q; + } + switch (qrem) { + case 0: break; + case 1: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - 
krem, info); + } + } break; + case 2: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info); + } + } break; + case 3: { + for (int ik = 0; ik < k_step/nrc_k; ++ik) { + mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info); + } + if constexpr (krem > 0) { + mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info); + } + } break; + } + float32x4_t vk[k_step/4]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } + } +#endif #ifdef __aarch64__ static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) { @@ -7867,57 +13652,138 @@ struct FlashQKfp32 { } #endif - template - static inline void mul_mask_kq(const KHelper& kh, int stride_m, - const block_q8 * q, const char * mask, FlashMS& fms) { - static_assert(q_step <= 8); + template + static inline std::pair mul_mat_kernel(int nq) { + constexpr int kMaxQ = 8; +#define MAKE_FUNCS(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat, 1>, 1);\ + case 2: return std::make_pair(mul_mat, 2>, 2);\ + case 3: return std::make_pair(mul_mat, 3>, 3);\ + case 4: return std::make_pair(mul_mat, 4>, 4);\ + case 5: return std::make_pair(mul_mat, 5>, 5);\ + case 6: return std::make_pair(mul_mat, 6>, 6);\ + case 7: return std::make_pair(mul_mat, 7>, 7);\ + }\ + } +#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \ + if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ);\ + else {\ + switch (n) {\ + case 1: return std::make_pair(mul_mat<1>, 1);\ + case 2: return std::make_pair(mul_mat<2>, 2);\ + case 3: return std::make_pair(mul_mat<3>, 3);\ + case 4: return std::make_pair(mul_mat<4>, 4);\ + case 5: return std::make_pair(mul_mat<5>, 5);\ + case 6: return 
std::make_pair(mul_mat<6>, 6);\ + case 7: return std::make_pair(mul_mat<7>, 7);\ + }\ + } if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0_T>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0= 128) { - mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); +#ifdef HAVE_FANCY_SIMD + MAKE_FUNCS(mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); + // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0 + MAKE_FUNCS(mul_mat_qX_0_q8_0_T>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; + else if constexpr (std::is_same_v>) { #ifdef __aarch64__ - mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); + if constexpr (D == 128) { + if (q_step >= 64 && nq >= 64) { + return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64); + } + else if (q_step >= 32 && nq >= 32) { + return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32); + } + else if (q_step >= 16 && nq >= 16) { + return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16); + } + else { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq); + } + } else { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); + } + //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); #else - mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); +#ifdef HAVE_FANCY_SIMD + if constexpr (D == 128) { + if (q_step >= 64 && nq >= 64) { + return std::make_pair(mul_mat_q8_0_r4_q8_1_128<64>, 64); + } + else if (q_step >= 32 && nq >= 32) { + return std::make_pair(mul_mat_q8_0_r4_q8_1_128<32>, 32); + } + else if (q_step >= 16 && nq >= 16) { + 
return std::make_pair(mul_mat_q8_0_r4_q8_1_128<16>, 16); + } + else { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1_128, nq); + } + } else { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); + } +#else + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); +#endif +#endif + } + else if constexpr (std::is_same_v>) { +#ifdef __aarch64__ + MAKE_FUNCS(mul_mat_qX_1_q8_1>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_1_q8_1_T>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + MAKE_FUNCS(mul_mat_qX_1_q8_1_T(nullptr, 0); + } + + template + static inline void mul_mask_kq(const KHelper& kh, int stride_m, + const block_q8 * q, const char * mask, FlashMS& fms) { + constexpr int kMaxQ = 8; + static_assert(q_step < kMaxQ || q_step%kMaxQ == 0); + auto [mul_mat, nrc_q] = mul_mat_kernel(q_step); + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + for (int iq = 0; iq < q_step/nrc_q; ++iq) { + mul_mat(D, kh.block, kh.stride, info, k_step); + info.cur_y += nrc_q; + } #ifdef __aarch64__ float32x4_t vk[k_step/4]; for (int j = 0; j < q_step; ++j) { @@ -7930,136 +13796,21 @@ struct FlashQKfp32 { } #endif } + template static inline void mul_mask_kq(int nq, const KHelper& kh, int stride_m, const block_q8 * q, const char * mask, FlashMS& fms) { - GGML_ASSERT(nq < 8); - if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0(D, 
kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; -#endif - } + auto [mul_mat, nrc_q] = mul_mat_kernel(nq); + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + for (int iq = 0; iq < nq/nrc_q; ++iq) { + mul_mat(D, kh.block, kh.stride, info, k_step); + info.cur_y += nrc_q; } - else if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; -#ifdef __aarch64__ - switch (nq) { - case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - } -#else - if constexpr (D >= 128) { - switch (nq) { - case 1: mul_mat_qX_0_q8_0_T(D, 
kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - } - } else { - switch (nq) { - case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; - } - } -#endif - } - else if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); 
break; - case 4: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else if constexpr (std::is_same_v>) { - DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; - switch (nq) { -#ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0(D, 
kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; -#else - case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; -#endif - } - } - else { - GGML_ASSERT(false); + int iq = nrc_q*(nq/nrc_q); + if (iq < nq) { + auto [mul_mat1, nrc_q1] = mul_mat_kernel(nq - iq); + GGML_ASSERT(nrc_q1 == nq - iq); + mul_mat1(D, kh.block, kh.stride, info, k_step); } #ifdef __aarch64__ float32x4_t vk[k_step/4]; @@ -8083,6 +13834,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in #ifdef __aarch64__ float16_t q_f16[D*q_step]; #endif + for (int i1 = 0; i1 < nq1/q_step; ++i1) { fms.init_qstep(); kh.reset_block(); @@ -8138,20 +13890,44 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, FlashQKV& fqkv, const float * q, const char * mask, float * qkv) { typename KHelper::block_q8 q8[q_step*(D/QK8_0)]; +#if FA_TIMING + Perf perf(false); +#endif for (int i1 = 0; i1 < nq1/q_step; ++i1) { +#if FA_TIMING + auto t1 = Perf::cur_time(); +#endif fms.init_qstep(); kh.reset_block(); vh.reset_block(); HelperQ80::convert(q_step, stride_q, q, q8); +#if FA_TIMING + perf.accum_nolock(0, t1); +#endif auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { +#if FA_TIMING + t1 = Perf::cur_time(); + KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); + perf.accum_nolock(1, 
t1); + t1 = Perf::cur_time(); + fqkv.accumulate_qkv(vh, fms); + perf.accum_nolock(2, t1); +#else KQHelper::mul_mask_kq(kh, stride_m, q8, mr, fms); fqkv.accumulate_qkv(vh, fms); +#endif kh.next_block(); vh.next_block(); mr += k_step*sizeof(ggml_half); } +#if FA_TIMING + t1 = Perf::cur_time(); fqkv.normalize_and_store(fms, stride_qkv, qkv); + perf.accum_nolock(3, t1); +#else + fqkv.normalize_and_store(fms, stride_qkv, qkv); +#endif q += q_step*stride_q; mask += q_step*stride_m; @@ -8173,6 +13949,9 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, } fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); } +#if FA_TIMING + Perf::instance().add(perf); +#endif } // Some of the methods in FlashAttn have two identical implementations that only differ by @@ -8196,11 +13975,38 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { +// if constexpr (std::is_same_v> || std::is_same_v> || +// std::is_same_v> || +// std::is_same_v> || +// std::is_same_v> || +// std::is_same_v>) { +// compute_helper_q>( +// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); +// } else { +// compute_helper>( +// kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); +// } if constexpr (std::is_same_v> || std::is_same_v> || - std::is_same_v> || std::is_same_v> || + std::is_same_v> || std::is_same_v>) { compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } + else if constexpr (std::is_same_v>) { + if (nq1 >= 8) { +#if FA_TIMING + auto t1 = Perf::cur_time(); + HelperQ80R4 khr4(nk1, kh); + Perf::instance().accum(4, t1); +#else + HelperQ80R4 khr4(nk1, kh); +#endif + compute_helper_q, VHelper, FlashQKfp32>( + khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); + } else{ + compute_helper_q>( + kh, vh, nq1, nk1, stride_q, stride_m, 
stride_qkv, fms, fqkv, q, mask, qkv); + } } else { compute_helper>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); @@ -8240,6 +14046,10 @@ struct HelperBF16 final : public BaseHelper { load(l1+2, vk+2*D/32); load(l1+3, vk+3*D/32); } + + inline void load_8(int l1, __m512bh * vk) const { + for (int k = 0; k < 8; ++k) load(l1 + k, vk + k*D/32); + } }; template @@ -8342,6 +14152,11 @@ struct FlashQKbf16 { } } + static inline __m128 hsum_float_4x4(__m128 * a) { + for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2])); + return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1])); + } + template static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, FlashMS& fms) { @@ -8371,8 +14186,98 @@ struct FlashQKbf16 { } } + static inline void mult_mask_kq_4(int l1, int m1, const ggml_bf16_t * q, + __m512bh * qv, const __m512bh * vkh, FlashMS& fms) { + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i)); + __m128 sum[4]; + for (int k = 0; k < 4; ++k) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]); + auto aux = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1)); + sum[k] = _mm_add_ps(_mm256_castps256_ps128(aux), _mm256_extractf128_ps(aux, 1)); + } + //auto sum4 = _mm_mask_blend_ps(m8, hsum_float_4x4(sum), _mm_set1_ps(-INFINITY)); + //_mm_storeu_ps(fms.cache + k_step*m1 + l1, sum4); + _mm_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_4x4(sum)); + } + + static IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31)); + //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), 
_mm256_extractf128_ps(accm[i+4], 1)), + // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); + } + + static inline void mult_mask_kq_8(int l1, int m1, const ggml_bf16_t * q, + __m512bh * qv, const __m512bh * vkh, FlashMS& fms) { + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i *)qr + i)); + __m256 sum[8]; + for (int k = 0; k < 8; ++k) { + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+k*(D/32)], qv[i]); + sum[k] = _mm256_add_ps(_mm512_castps512_ps256(vsum), _mm512_extractf32x8_ps(vsum, 1)); + } + _mm256_storeu_ps(fms.cache + k_step*m1 + l1, hsum_float_8x8(sum)); + } + + static inline void mult_mask_kq_one(int l1, int m1, const ggml_bf16_t * q, + __m512bh * qv, const __m512bh * vkh, FlashMS& fms) { + auto qr = q + m1*D; + for (int i = 0; i < D/32; ++i) qv[i] = __m512bh(_mm512_loadu_si512((const __m512i*)qr + i)); + auto vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i], qv[i]); + fms.cache[k_step*m1 + l1 + 0] = _mm512_reduce_add_ps(vsum); + vsum = _mm512_setzero_ps(); + for (int i = 0; i < D/32; ++i) vsum = _mm512_dpbf16_ps(vsum, vkh[i+D/32], qv[i]); + fms.cache[k_step*m1 + l1 + 1] = _mm512_reduce_add_ps(vsum); + } + +#if FA_TIMING template static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, + const char * mask, FlashMS& fms, Perf& perf) { + auto t1 = Perf::cur_time(); +#else + template + static inline void multiply_mask_kq(const KHelper& kh, int stride_m, const ggml_bf16_t * q, + const char * mask, FlashMS& fms) { +#endif + { + __m512bh qv[D/32]; + if constexpr (D <= 128) { + __m512bh vkh[D/4]; + for 
(int l1 = 0; l1 < k_step; l1 += 8) { + kh.load_8(l1, vkh); + for (int j = 0; j < q_step; ++j) mult_mask_kq_8(l1, j, q, qv, vkh, fms); + } + } else { + __m512bh vkh[D/16]; + for (int l1 = 0; l1 < k_step; l1 += 2) { + kh.load_2(l1, vkh); + for (int j = 0; j < q_step; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms); + } + } + } +#if FA_TIMING + perf.accum_nolock(1, t1); + t1 = Perf::cur_time(); +#endif + F16::Data vk[k_step/16]; + for (int j = 0; j < q_step; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); + } +#if FA_TIMING + perf.accum_nolock(2, t1); +#endif + } + + template + static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_m, const ggml_bf16_t * q, const char * mask, FlashMS& fms) { { __m512bh qv[D/32]; @@ -8380,23 +14285,19 @@ struct FlashQKbf16 { __m512bh vkh[D/8]; for (int l1 = 0; l1 < k_step; l1 += 4) { kh.load_4(l1, vkh); - for (int j = 0; j < q_step; ++j) { - mult_mask_kq_4(l1, j, stride_m, q, mask, qv, vkh, fms); - } + for (int j = 0; j < nq; ++j) mult_mask_kq_4(l1, j, q, qv, vkh, fms); } } else { __m512bh vkh[D/16]; for (int l1 = 0; l1 < k_step; l1 += 2) { kh.load_2(l1, vkh); - for (int j = 0; j < q_step; ++j) { - mult_mask_kq_one(l1, j, stride_m, q, mask, qv, vkh, fms); - } + for (int j = 0; j < nq; ++j) mult_mask_kq_one(l1, j, q, qv, vkh, fms); } } } - __m512 vk[k_step/16]; - for (int j = 0; j < q_step; ++j) { - fms.update_M_S(j, vk); + F16::Data vk[k_step/16]; + for (int j = 0; j < nq; ++j) { + fms.update_M_S(j, vk, mask + stride_m*j); } } @@ -8431,6 +14332,19 @@ struct FlashQKbf16 { bf16 += D; } } + + static inline void convert(int nq, int stride_q, const float * q, ggml_bf16_t * bf16) { + auto qr = q; + for (int j = 0; j < nq; ++j) { + for (int i = 0; i < D/32; ++i) { + auto val1 = _mm512_loadu_ps(qr + 32*i); + auto val2 = _mm512_loadu_ps(qr + 32*i + 16); + _mm512_storeu_si512((__m512i *)bf16 + i, (__m512i)_mm512_cvtne2ps_pbh(val2, val1)); + } + qr += stride_q; + bf16 += D; + } + } }; template @@ -8445,20 +14359,44 @@ struct 
FlashAttnBF16 { void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { ggml_bf16_t q_bf16[q_step*D]; +#if FA_TIMING + Perf perf(false); +#endif for (int i1 = 0; i1 < nq1/q_step; ++i1) { +#if FA_TIMING + auto t1 = Perf::cur_time(); +#endif fms.init_qstep(); kh.reset_block(); vh.reset_block(); FlashQKbf16::convert(stride_q, q, q_bf16); +#if FA_TIMING + perf.accum_nolock(0, t1); +#endif auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { +#if FA_TIMING + //t1 = Perf::cur_time(); + FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms, perf); + //perf.accum_nolock(1, t1); + t1 = Perf::cur_time(); + fqkv.accumulate_qkv(vh, fms); + perf.accum_nolock(3, t1); +#else FlashQKbf16::multiply_mask_kq(kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(vh, fms); +#endif kh.next_block(); vh.next_block(); mr += k_step*sizeof(ggml_half); } +#if FA_TIMING + t1 = Perf::cur_time(); +#endif fqkv.normalize_and_store(fms, stride_qkv, qkv); +#if FA_TIMING + perf.accum_nolock(4, t1); +#endif q += q_step*stride_q; mask += q_step*stride_m; @@ -8469,9 +14407,10 @@ struct FlashAttnBF16 { fms.init_qstep(); kh.reset_block(); vh.reset_block(); + FlashQKbf16::convert(n_left, stride_q, q, q_bf16); auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { - FlashQKbf16::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); + FlashQKbf16::multiply_mask_kq(n_left, kh, stride_m, q_bf16, mr, fms); fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block(); @@ -8479,6 +14418,9 @@ struct FlashAttnBF16 { } fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv); } +#if FA_TIMING + Perf::instance().add(perf); +#endif } FlashMS fms; @@ -8486,28 +14428,58 @@ struct FlashAttnBF16 { }; #endif -template +template inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float scale, float 
softcap, float * qkv) { - if (nq1 >= q_step) { - FlashAttn fa(scale, softcap); + if (nk1 >= 256) { //4096) { + if (nq1 >= 64) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + if (nq1 >= 32) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + if (nq1 >= 16) { + FlashAttn fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + } + if (nq1 >= 8) { + FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); - } else { + } + else { FlashAttn fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } } #ifdef __AVX512BF16__ -template +template inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, float scale, float softcap, float * qkv) { HelperBF16 kh(k, stride_k); HelperBF16 vh(v, stride_v); - if (nq1 >= q_step) { - FlashAttnBF16 fa(scale, softcap); + if (nk1 >= 4096) { + if (nq1 >= 64) { + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + else if (nq1 >= 16) { + FlashAttnBF16 fa(scale, softcap); + fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); + return; + } + } + if (nq1 >= 8) { + FlashAttnBF16 fa(scale, softcap); fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv); } else { FlashAttnBF16 fa(scale, softcap); @@ -8516,7 +14488,7 @@ inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int } #endif -template +template inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, int 
nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv, const float * q, const char * v, const char * mask, @@ -8525,33 +14497,39 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, switch (type_v) { case GGML_TYPE_F16: { HelperF16 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; +#ifdef HAVE_FANCY_SIMD + case GGML_TYPE_BF16: { + HelperBF16 vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; +#endif case GGML_TYPE_Q8_0: { HelperQ80 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_0: { HelperQ40 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { HelperQ41 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { HelperQ60 vh(v, stride_v); - iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; default: break; 
} } -template +template inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv, const float * q, const char * k, const char * v, const char * mask, @@ -8560,27 +14538,27 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, switch (type_k) { case GGML_TYPE_F16: { HelperF16 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q8_0: { HelperQ80 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_0: { HelperQ40 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q4_1: { HelperQ41 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_IQ4_NL: { HelperIQ4nl kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; case GGML_TYPE_Q6_0: { HelperQ60 kh(k, stride_k); - iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, 
scale, softcap, qkv); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; default: break; } @@ -8627,17 +14605,33 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k stride_q /= sizeof(float); // q stride as float #ifdef __AVX512BF16__ - if (type_k == GGML_TYPE_BF16 || type_v == GGML_TYPE_BF16) { - if (type_k != GGML_TYPE_BF16 || type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 with other types + if (type_k == GGML_TYPE_BF16) { + if (nk1%64 == 0) { + if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types + switch (D) { + case 64: + iqk_flash_helper_T< 64, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 256: + iqk_flash_helper_T<256, 64>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + default: + return false; + } + return true; + } + if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types switch (D) { case 64: - iqk_flash_helper_T< 64, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); 
break; case 128: - iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } @@ -8646,21 +14640,42 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k } #endif + if (nk1%64 == 0) { + switch (D) { + case 64: + iqk_flash_helper_T< 64, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 80: + // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 96: + iqk_flash_helper_T< 96, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // Disable until we fix accumulate_qkv for odd D/16 + //case 112: + // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 128: + iqk_flash_helper_T<128, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + case 256: + iqk_flash_helper_T<256, 64>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + default: + return false; + } + return true; + } switch (D) { case 64: - iqk_flash_helper_T< 64, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, 
cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 64, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 80: - // iqk_flash_helper_T< 80, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // iqk_flash_helper_T< 80, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 96: - iqk_flash_helper_T< 96, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T< 96, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; // Disable until we fix accumulate_qkv for odd D/16 //case 112: - // iqk_flash_helper_T<112, 4, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + // iqk_flash_helper_T<112, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: - iqk_flash_helper_T<128, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<128, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 256: - iqk_flash_helper_T<256, F16::q_step, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; + iqk_flash_helper_T<256, 32>(type_k, type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false; } diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 306e3ff0..443fda12 100644 --- 
a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -23,6 +23,9 @@ #include #include #include +#include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -364,7 +367,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si *s = d8[0] * (sumi[0] + sumi[1]) + d8[1] * (sumi[2] + sumi[3]) + d8[2] * (sumi[4] + sumi[5]) + d8[3] * (sumi[6] + sumi[7]); } -void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +void vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { GGML_ASSERT(nrc == 1); GGML_UNUSED(bs); @@ -522,6 +525,388 @@ void quantize_row_q8_K64(const float * x, void * y, int64_t k) { quantize_row_q8_K64_ref(x, (block_q8_K64 *)y, k); } +#ifdef __AVX2__ +namespace { +inline float hsum_float_4(__m128 x) { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +inline float hsum_float_8(__m256 x) { + return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); +} +inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +inline float hmax_f32_8(__m256 x) { + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); + return _mm_cvtss_f32(max4); +} +} +#endif + +void quantize_row_q8_K16(const float * x, void * vy, int64_t nk) { + float * dptr = (float *)vy; + int8_t * qy = (int8_t *)(dptr + 5); + 
int n64 = nk / 64; +#ifdef z__AVX2__ + __m256 sign_bit = _mm256_set1_ps(-0.f); + __m256 vmax[4] = {}; + __m256 vsum[4] = {}; + for (int i64 = 0; i64 < n64; ++i64) { + for (int k = 0; k < 4; ++k) { + auto v1 = _mm256_loadu_ps(x + 64*i64 + 16*k + 0); + auto v2 = _mm256_loadu_ps(x + 64*i64 + 16*k + 8); + vsum[k] = _mm256_add_ps(vsum[k], _mm256_add_ps(v1, v2)); + v1 = _mm256_andnot_ps(sign_bit, v1); + v2 = _mm256_andnot_ps(sign_bit, v2); + vmax[k] = _mm256_max_ps(vmax[k], _mm256_max_ps(v1, v2)); + } + } + __m256 sum = _mm256_add_ps(_mm256_add_ps(vsum[0], vsum[1]), _mm256_add_ps(vsum[2], vsum[3])); + dptr[4] = hsum_float_8(sum); + for (int k = 0; k < 4; ++k) { + float max = hmax_f32_8(vmax[k]); + dptr[k] = max/127; + vmax[k] = _mm256_set1_ps(dptr[k] > 0 ? 1/dptr[k] : 0.f); + } + __m256i ival[8]; + const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + for (int i64 = 0; i64 < n64; ++i64) { + for (int k = 0; k < 4; ++k) { + __m256 v0 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 0)); + __m256 v1 = _mm256_mul_ps(vmax[k], _mm256_loadu_ps(x + 64*i64 + 16*k + 8)); + v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST); + v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); + ival[2*k+0] = _mm256_cvtps_epi32(v0); + ival[2*k+1] = _mm256_cvtps_epi32(v1); + } + for (int k = 0; k < 2; ++k) { + auto i0 = _mm256_packs_epi32(ival[4*k+0], ival[4*k+1]); + auto i1 = _mm256_packs_epi32(ival[4*k+2], ival[4*k+3]); + i0 = _mm256_packs_epi16(i0, i1); + i0 = _mm256_permutevar8x32_epi32(i0, perm); + _mm256_storeu_si256((__m256i *)qy, i0); + qy += 32; + } + } +#elif defined z__ARM_NEON + static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}; + auto shuffle = vld1q_u8(k_shuffle); + float32x4_t vmax[4] = {}; + float32x4_t vsum[4] = {}; + for (int i64 = 0; i64 < n64; ++i64) { + for (int k = 0; k < 4; ++k) { + auto v = vld1q_f32_x4(x + 64*i64 + 16*k); + vsum[k] = vaddq_f32(vsum[k], vaddq_f32(v.val[0], v.val[1])); + vsum[k] = vaddq_f32(vsum[k], 
vaddq_f32(v.val[2], v.val[3])); + vmax[k] = vmaxq_f32(vmax[k], vmaxq_f32(vabsq_f32(v.val[0]), vabsq_f32(v.val[1]))); + vmax[k] = vmaxq_f32(vmax[k], vmaxq_f32(vabsq_f32(v.val[2]), vabsq_f32(v.val[3]))); + } + } + dptr[4] = vaddvq_f32(vaddq_f32(vaddq_f32(vsum[0], vsum[1]), vaddq_f32(vsum[2], vsum[3]))); + for (int k = 0; k < 4; ++k) { + float max = vmaxvq_f32(vmax[k]); + dptr[k] = max/127; + vmax[k] = vdupq_n_f32(dptr[k] > 0 ? 1/dptr[k] : 0.f); + } + int8x16x4_t q; + for (int i64 = 0; i64 < n64; ++i64) { + for (int k = 0; k < 4; ++k) { + auto v = vld1q_f32_x4(x + 64*i64 + 16*k); + for (int j = 0; j < 4; ++j) { + q.val[j] = vreinterpretq_s8_s32(vcvtnq_s32_f32(vmulq_f32(vmax[k], v.val[j]))); + } + auto qi = vqtbl4q_s8(q, shuffle); + vst1q_s8(qy, qi); + qy += 16; + } + } +#else + float amax[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i64 = 0; i64 < n64; ++i64) { + for (int k = 0; k < 4; ++k) { + for (int j = 0; j < 16; ++j) { + float ax = std::abs(x[64*i64 + 16*k + j]); + amax[k] = std::max(amax[k], ax); + } + } + } + for (int k = 0; k < 4; ++k) { + dptr[k] = amax[k]/127; + amax[k] = dptr[k] > 0 ? 
1/dptr[k] : 0.f; + } + int sumi[4] = {}; + for (int i64 = 0; i64 < n64; ++i64) { + for (int k = 0; k < 4; ++k) { + for (int j = 0; j < 16; ++j) { + int ix = nearest_int(amax[k]*x[64*i64 + 16*k + j]); + sumi[k] += ix; + qy[64*i64 + 16*k + j] = ix; + } + } + } + dptr[4] = dptr[0]*sumi[0] + dptr[1]*sumi[1] + dptr[2]*sumi[2] + dptr[3]*sumi[3]; +#endif +} + +void quantize_row_q8_0_x4(const float * x, void * vy, int64_t k) { + const int nb = k / QK8_0; + const int nb4 = 4*(nb/4); + + block_q8_0 * y = (block_q8_0 *)vy; + block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy; +#if defined(__aarch64__) + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 
1.0f/d : 0.0f; + + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + if (i < nb4) { + y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } + } +#else + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + const float d = maxScalar / 127.f; + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + const float id = ( maxScalar != 0.0f ) ? 
127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } + } +#endif +} + +void quantize_row_q8_1_x4(const float * x, void * vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + const int nb4 = 4*(nb/4); + block_q8_1 * y = (block_q8_1 *)vy; + block_q8_1_x4 * y4 = (block_q8_1_x4 *)vy; +#if defined(__aarch64__) + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = 
vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + if (i < nb4) { + y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0); + y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1); + y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2); + y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3); + } else { + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + + accv = vaddq_s32(accv, vi); + } + + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } else { + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } + } +#else + for (int i = 0; i < nb; i++) { + int i4 = i/4, ir = i%4; + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + 
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float max_scalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = max_scalar / 127.f; + if (i < nb4) { + y4[i4].d[ir] = GGML_FP32_TO_FP16(d); + } else { + y[i].d = GGML_FP32_TO_FP16(d); + } + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Compute the sum of the quants and set y[i].s + if (i < nb4) { + y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } else { + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + } + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = 
_mm256_permutevar8x32_epi32( i0, perm ); + + if (i < nb4) { + _mm256_storeu_si256((__m256i *)y4[i4].qs + ir, i0); + } else { + _mm256_storeu_si256((__m256i *)y[i].qs, i0); + } + } +#endif +} + // // ============================================== iq2_K // @@ -671,12 +1056,12 @@ void quantize_row_iq2_k_impl(const float * x, void * vy, int n_per_row, const fl } } -void quantize_row_iq2_k_ref(const float * GGML_RESTRICT x, block_iq2_k * GGML_RESTRICT y, int64_t k) { +void quantize_row_iq2_k_ref(const float * x, block_iq2_k * y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_k(x, (void *)y, 1, k, nullptr); } -void quantize_row_iq2_k(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void quantize_row_iq2_k(const float * x, void * vy, int64_t k) { assert(k % QK_K == 0); block_iq2_k * y = (block_iq2_k *)vy; quantize_row_iq2_k_ref(x, y, k); @@ -694,7 +1079,7 @@ size_t quantize_iq2_k(const float * src, void * dst, int64_t nrows, int64_t n_pe return nrows * nblock * sizeof(block_iq2_k); } -void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_iq2_k(const block_iq2_k * x, float * y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -725,7 +1110,7 @@ void dequantize_row_iq2_k(const block_iq2_k * GGML_RESTRICT x, float * GGML_RES } -void vec_dot_iq2_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void vec_dot_iq2_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); GGML_UNUSED(nrc); @@ -969,12 +1354,12 @@ void quantize_row_iq2_ks_impl(const float * x, void * vy, int n_per_row, const f } } -void quantize_row_iq2_ks_ref(const float * GGML_RESTRICT x, block_iq2_ks * GGML_RESTRICT y, int64_t k) { +void quantize_row_iq2_ks_ref(const float * x, block_iq2_ks * y, int64_t k) { assert(k % QK_K == 0); 
quantize_iq2_ks(x, (void *)y, 1, k, nullptr); } -void quantize_row_iq2_ks(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void quantize_row_iq2_ks(const float * x, void * vy, int64_t k) { assert(k % QK_K == 0); block_iq2_ks * y = (block_iq2_ks *)vy; quantize_row_iq2_ks_ref(x, y, k); @@ -996,7 +1381,7 @@ size_t quantize_iq2_ks(const float * src, void * dst, int64_t nrows, int64_t n_p return nrows * row_size; } -void dequantize_row_iq2_ks(const block_iq2_ks * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +void dequantize_row_iq2_ks(const block_iq2_ks * x, float * y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1336,7 +1721,7 @@ void dequantize_row_iq3_k(const block_iq3_k * x, float * y, int64_t k) { } } -void vec_dot_iq3_k_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void vec_dot_iq3_k_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); GGML_UNUSED(nrc); @@ -2341,25 +2726,8 @@ size_t quantize_iq6_k(const float * src, void * dst, int64_t nrows, int64_t n_pe return nrows * nblock * sizeof(block_iq6_k); } -#ifdef __AVX2__ -namespace { -inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} -inline float hmax_f32_8(__m256 x) { - __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); - max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4)); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); - return _mm_cvtss_f32(max4); -} -} -#endif - -void iqk_quantize_row_q8_K(const float * x, void * 
vy, int64_t k) { +template +void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; block_q8_K * y = (block_q8_K *)vy; @@ -2394,8 +2762,14 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { __m256i i1 = _mm256_cvtps_epi32(v1); __m256i i2 = _mm256_cvtps_epi32(v2); __m256i i3 = _mm256_cvtps_epi32(v3); - y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1)); - y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3)); + if constexpr (q8_type > 0) { + int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + auto bs = (float *)y[i].bsums; + bs[ib] = d*bsum; + } else { + y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1)); + y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3)); + } i0 = _mm256_packs_epi32( i0, i1 ); i2 = _mm256_packs_epi32( i2, i3 ); i0 = _mm256_packs_epi16( i0, i2 ); @@ -2403,6 +2777,12 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { _mm256_storeu_si256((__m256i *)q8, i0); q8 += 32; } + if constexpr (q8_type == 2) { + auto bs = (float *)y[i].bsums; + float sum = 0; + for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib]; + bs[0] = sum; + } } #else for (int i = 0; i < nb; i++) { @@ -2428,12 +2808,29 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { int v = nearest_int(iscale*x[j]); y[i].qs[j] = MIN(127, v); } - for (int j = 0; j < QK_K/16; ++j) { - int sum = 0; - for (int ii = 0; ii < 16; ++ii) { - sum += y[i].qs[j*16 + ii]; + if constexpr (q8_type > 0) { + auto bs = (float *)y[i].bsums; + float d = 1/iscale; + float sum = 0; + for (int j = 0; j < QK_K/32; ++j) { + int sum = 0; + for (int ii = 0; ii < 32; ++ii) { + sum += y[i].qs[j*32 + ii]; + } + bs[j] = d*sum; + sum += bs[j]; + } + if constexpr (q8_type == 2) { + bs[0] = sum; + } + } else { + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = 
sum; } - y[i].bsums[j] = sum; } y[i].d = 1/iscale; x += QK_K; @@ -2442,6 +2839,18 @@ void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { } +void iqk_quantize_row_q8_K(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_K_T<0>(x, vy, k); +} + +void quantize_row_q8_K32(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_K_T<1>(x, vy, k); +} + +void quantize_row_q8_KR8(const float * x, void * vy, int64_t k) { + iqk_quantize_row_q8_K_T<2>(x, vy, k); +} + namespace { static void quantize_row_iq4_k_impl_bs128(const int super_block_size, const int block_size, int n_per_row, const float * x, char * cy, @@ -3121,18 +3530,2627 @@ void vec_dot_iq4_kss_q8_k(int n, float * s, size_t bs, const void * vx, size_t b GGML_UNUSED(by); } -// ========================================== iq2_kt ==================================================== +// +// ========================================= iq4_nl_r4 +// +void quantize_row_iq4_nl_r4_ref(const float * x, block_iq4_nl_r4 * y, int64_t k) { + // we assume we are called with 4 rows + quantize_iq4_nl_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq4_nl_r4(const float * x, void * y, int64_t k) { + // we assume we are called with 4 rows + quantize_iq4_nl_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq4_nl(int nrows, int n_per_row, const block_iq4_nl * x, block_iq4_nl_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK4_NL == 0); + int nblock = n_per_row/QK4_NL; + const block_iq4_nl * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[4*k+i+ 0] = (x4[k][ib].qs[i+0] & 0xf) | ((x4[k][ib].qs[i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row + y[ib].qs[4*k+i+16] = (x4[k][ib].qs[i+0] >> 4) | ((x4[k][ib].qs[i+ 8] & 0xf0)); // 16...19 + 24...27 from each row 
+ y[ib].qs[4*k+i+32] = (x4[k][ib].qs[i+4] & 0xf) | ((x4[k][ib].qs[i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row + y[ib].qs[4*k+i+48] = (x4[k][ib].qs[i+4] >> 4) | ((x4[k][ib].qs[i+12] & 0xf0)); // 20...23 + 28...31 from each row + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq4_nl_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row); + std::vector qtmp(4*row_size_nl); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + quantize_iq4_nl(src, qtmp.data(), 4, n_per_row, imatrix); + repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_r4 *)qrow); + src += 4*n_per_row; + qrow += 4*row_size_nl; + } + return nrows*row_size_nl; +} + +void dequantize_row_iq4_nl_r4(const block_iq4_nl_r4 * x, float * y, int64_t k) { + // we assume we are called with 4 rows + int n_per_row = k/4; + int nb = n_per_row/QK4_NL; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nb; ++ib) { + for (int k = 0; k < 4; ++k) { + float scale = GGML_FP16_TO_FP32(x[ib].d[k]); + for (int i = 0; i < 4; ++i) { + yk[k][QK4_NL*ib+i+ 0] = scale * iq4k_values[x[ib].qs[4*k+i+ 0] & 0xf]; + yk[k][QK4_NL*ib+i+ 8] = scale * iq4k_values[x[ib].qs[4*k+i+ 0] >> 4]; + yk[k][QK4_NL*ib+i+16] = scale * iq4k_values[x[ib].qs[4*k+i+16] & 0xf]; + yk[k][QK4_NL*ib+i+24] = scale * iq4k_values[x[ib].qs[4*k+i+16] >> 4]; + yk[k][QK4_NL*ib+i+ 4] = scale * iq4k_values[x[ib].qs[4*k+i+32] & 0xf]; + yk[k][QK4_NL*ib+i+12] = scale * iq4k_values[x[ib].qs[4*k+i+32] >> 4]; + yk[k][QK4_NL*ib+i+20] = scale * iq4k_values[x[ib].qs[4*k+i+48] & 0xf]; + yk[k][QK4_NL*ib+i+28] = scale * iq4k_values[x[ib].qs[4*k+i+48] >> 4]; + } + } + } +} + +void vec_dot_iq4_nl_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if 
(iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_NL_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q4_0_r4 +// +void quantize_row_q4_0_r4_ref(const float * x, block_iq4_nl_r4 * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q4_0_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q4_0_r4(const float * x, void * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q4_0_r4(x, y, 4, k/4, nullptr); +} + +static void repack_q4_0(int nrows, int n_per_row, const block_q4_0 * x, block_iq4_nl_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK4_NL == 0); + int nblock = n_per_row/QK4_NL; + const block_q4_0 * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ib = 0; ib < nblock; ++ib) { + //for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + //for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + // y[ib].qs[4*k+i+ 0] = (x4[k][ib].qs[i+0] & 0xf) | ((x4[k][ib].qs[i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row + // y[ib].qs[4*k+i+16] = (x4[k][ib].qs[i+0] >> 4) | ((x4[k][ib].qs[i+ 8] & 0xf0)); // 16...19 + 24...27 from each row + // y[ib].qs[4*k+i+32] = (x4[k][ib].qs[i+4] & 0xf) | ((x4[k][ib].qs[i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row + // y[ib].qs[4*k+i+48] = (x4[k][ib].qs[i+4] >> 4) | ((x4[k][ib].qs[i+12] & 0xf0)); // 20...23 + 28...31 from each row + //} + for (int k = 0; k < 4; ++k) { + y[ib].d[k] = x4[k][ib].d; + for (int l = 0; l < 4; ++l) { + // l = 0 -> 0, 8 with shift 0 -> 4*(l/2), 4*(l/2)+8 with shift 4*(l%2) + // l = 1 -> 0, 8 with shift 4 + // l = 2 -> 4, 12 with shift 0 + // l = 3 -> 4, 12 with shift 4 + for (int i = 0; i < 4; ++i) { + y[ib].qs[4*k+i+16*l] = ((x4[k][ib].qs[i+4*(l/2)] >> 4*(l%2)) & 0xf) | (((x4[k][ib].qs[i+4*(l/2)+8] >> 
4*(l%2)) & 0xf) << 4); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q4_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + auto row_size_nl = ggml_row_size(GGML_TYPE_IQ4_NL, n_per_row); + std::vector qtmp(4*row_size_nl); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + quantize_q4_0(src, qtmp.data(), 4, n_per_row, imatrix); + repack_iq4_nl(4, n_per_row, (const block_iq4_nl *)qtmp.data(), (block_iq4_nl_r4 *)qrow); + src += 4*n_per_row; + qrow += 4*row_size_nl; + } + return nrows*row_size_nl; +} + +void dequantize_row_q4_0_r4(const block_iq4_nl_r4 * x, float * y, int64_t k) { + // we assume we are called with 4 rows + int n_per_row = k/4; + int nb = n_per_row/QK4_0; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nb; ++ib) { + for (int k = 0; k < 4; ++k) { + float scale = GGML_FP16_TO_FP32(x[ib].d[k]); + for (int l = 0; l < 4; ++l) { + int ll = 16*(l%2) + 4*(l/2); + for (int i = 0; i < 4; ++i) { + yk[k][QK4_0*ib+i+ll+0] = scale * ((x[ib].qs[4*k+i+16*l] & 0xf) - 8); + yk[k][QK4_0*ib+i+ll+8] = scale * ((x[ib].qs[4*k+i+16*l] >> 4) - 8); + } + } + } + } +} + +void vec_dot_q4_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q4_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + + +// +// ========================================= q8_0_r4 +// +void quantize_row_q8_0_r4_ref(const float * x, block_q8_0_x4 * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q8_0_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q8_0_r4(const float * x, void * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q8_0_r4(x, y, 
4, k/4, nullptr); +} + +static void repack_q8_0(int nrows, int n_per_row, const block_q8_0 * x, block_q8_0_x4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK8_0 == 0); + int nblock = n_per_row/QK8_0; + const block_q8_0 * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) y[ib].d[k] = x4[k][ib].d; + for (int l = 0; l < 4; ++l) { + for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q8_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); + std::vector qtmp(4*row_size_0); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + quantize_q8_0(src, qtmp.data(), 4, n_per_row, imatrix); + repack_q8_0(4, n_per_row, (const block_q8_0 *)qtmp.data(), (block_q8_0_x4 *)qrow); + src += 4*n_per_row; + qrow += 4*row_size_0; + } + return nrows*row_size_0; +} + +void dequantize_row_q8_0_r4(const block_q8_0_x4 * x, float * y, int64_t k) { + // we assume we are called with 4 rows + int n_per_row = k/4; + int nb = n_per_row/QK8_0; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nb; ++ib) { + for (int k = 0; k < 4; ++k) { + float scale = GGML_FP16_TO_FP32(x[ib].d[k]); + for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { + yk[k][QK8_0*ib+4*l+i+ 0] = scale * x[ib].qs[QK8_0*l+4*k+i+ 0]; + yk[k][QK8_0*ib+4*l+i+16] = scale * x[ib].qs[QK8_0*l+4*k+i+16]; + } + } + } +} + +void vec_dot_q8_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_0_R4, vx, 0, 
GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q5_0_r4 +// +void quantize_row_q5_0_r4_ref(const float * x, block_q5_0_r4 * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q5_0_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q5_0_r4(const float * x, void * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q5_0_r4(x, y, 4, k/4, nullptr); +} + +static inline void convert_q5_0(const block_q5_0& x, uint8_t * L) { + uint32_t qh; + memcpy(&qh, x.qh, sizeof(qh)); + + for (int j = 0; j < QK5_0/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + L[j ] = (x.qs[j] & 0x0F) | xh_0; + L[j + QK4_0/2] = (x.qs[j] >> 4) | xh_1; + } +} + +static void repack_q5_0(int nrows, int n_per_row, const block_q5_0 * x, block_q5_0_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK5_0 == 0); + int nblock = n_per_row/QK5_0; + const block_q5_0 * x4[4]; + uint8_t L[QK5_0]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ib = 0; ib < nblock; ++ib) { + std::memset(y[ib].qh, 0, QK5_0/2); + for (int k = 0; k < 4; ++k) { + y[ib].d[k] = x4[k][ib].d; + convert_q5_0(x4[k][ib], L); + for (int l = 0; l < 4; ++l) { + int l1 = 4*(l/2) + 16*(l%2), l2 = l1 + 8; + for (int i = 0; i < 4; ++i) { + y[ib].qs[4*k+i+16*l] = (L[i + l1] & 0xf) | ((L[i + l2] & 0xf) << 4); + y[ib].qh[4*k+i] |= ((L[i + l1] >> 4) | ((L[i + l2] >> 4) << 4)) << l; + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q5_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + auto row_size_0 = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); + std::vector qtmp(4*row_size_0); + char * qrow = (char 
*)dst; + for (int row = 0; row < nrows; row += 4) { + quantize_q5_0(src, qtmp.data(), 4, n_per_row, imatrix); + repack_q5_0(4, n_per_row, (const block_q5_0 *)qtmp.data(), (block_q5_0_r4 *)qrow); + src += 4*n_per_row; + qrow += 4*row_size_0; + } + return nrows*row_size_0; +} + +void dequantize_row_q5_0_r4(const block_q5_0_r4 * x, float * y, int64_t k) { + // we assume we are called with 4 rows + int n_per_row = k/4; + int nb = n_per_row/QK8_0; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nb; ++ib) { + for (int k = 0; k < 4; ++k) { + float d = GGML_FP16_TO_FP32(x[ib].d[k]); + float m = -16*d; + for (int l = 0; l < 4; ++l) { + int ll = 16*(l%2) + 4*(l/2); + for (int i = 0; i < 4; ++i) { + yk[k][QK4_0*ib+i+ll+0] = d * ((x[ib].qs[4*k+i+16*l] & 0xf) | (((x[ib].qh[4*k+i] >> (l+0)) & 1) << 4)) + m; + yk[k][QK4_0*ib+i+ll+8] = d * ((x[ib].qs[4*k+i+16*l] >> 4) | (((x[ib].qh[4*k+i] >> (l+4)) & 1) << 4)) + m; + } + } + } + } +} + +void vec_dot_q5_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q5_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q6_0_r4 +// +void quantize_row_q6_0_r4_ref(const float * x, block_q6_0_r4 * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q6_0_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q6_0_r4(const float * x, void * y, int64_t k) { + // we assume we are called with 4 rows + quantize_q6_0_r4(x, y, 4, k/4, nullptr); +} + +static inline void convert_q6_0(const block_q6_0& x, uint8_t * L) { + + for (int j = 0; j < QK6_0/2; ++j) { + const uint8_t h = x.qh[j%(QK6_0/4)] >> 4*(j/(QK6_0/4)); + L[j ] = (x.qs[j] & 0x0F) | ((h << 4) & 0x30); + L[j + QK6_0/2] = (x.qs[j] 
>> 4) | ((h << 2) & 0x30); + } +} + +static void repack_q6_0(int nrows, int n_per_row, const block_q6_0 * x, block_q6_0_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK5_0 == 0); + int nblock = n_per_row/QK6_0; + const block_q6_0 * x4[4]; + uint8_t L[QK6_0]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ib = 0; ib < nblock; ++ib) { + std::memset(y[ib].qh, 0, QK6_0); + for (int k = 0; k < 4; ++k) { + y[ib].d[k] = x4[k][ib].d; + convert_q6_0(x4[k][ib], L); + for (int l = 0; l < 4; ++l) { + int l1 = 4*(l/2) + 16*(l%2), l2 = l1 + 8; + for (int i = 0; i < 4; ++i) { + y[ib].qs[4*k+i+16*l] = (L[i + l1] & 0xf) | ((L[i + l2] & 0xf) << 4); + y[ib].qh[4*k+i+16*(l%2)] |= ((L[i + l1] >> 4) | ((L[i + l2] >> 4) << 4)) << 2*(l/2); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q6_0_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + auto row_size_0 = ggml_row_size(GGML_TYPE_Q6_0, n_per_row); + std::vector qtmp(4*row_size_0); + char * qrow = (char *)dst; + for (int row = 0; row < nrows; row += 4) { + quantize_q6_0(src, qtmp.data(), 4, n_per_row, imatrix); + repack_q6_0(4, n_per_row, (const block_q6_0 *)qtmp.data(), (block_q6_0_r4 *)qrow); + src += 4*n_per_row; + qrow += 4*row_size_0; + } + return nrows*row_size_0; +} + +void dequantize_row_q6_0_r4(const block_q6_0_r4 * x, float * y, int64_t k) { + // we assume we are called with 4 rows + int n_per_row = k/4; + int nb = n_per_row/QK6_0; + float * yk[4]; + for (int k = 0; k < 4; ++k) yk[k] = y + k*n_per_row; + for (int ib = 0; ib < nb; ++ib) { + for (int k = 0; k < 4; ++k) { + float d = GGML_FP16_TO_FP32(x[ib].d[k]); + float m = -32*d; + for (int l = 0; l < 4; ++l) { + int ll = 16*(l%2) + 4*(l/2); + for (int i = 0; i < 4; ++i) { + yk[k][QK4_0*ib+i+ll+0] = d * ((x[ib].qs[4*k+i+16*l] & 0xf) | (((x[ib].qh[4*k+i+16*(l%2)] >> (2*(l/2)+0)) & 3) << 4)) + m; + 
yk[k][QK4_0*ib+i+ll+8] = d * ((x[ib].qs[4*k+i+16*l] >> 4) | (((x[ib].qh[4*k+i+16*(l%2)] >> (2*(l/2)+4)) & 3) << 4)) + m; + } + } + } + } +} + +void vec_dot_q6_0_r4_q8_0(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q6_0_R4, vx, 0, GGML_TYPE_Q8_0, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq4_xs_r4 +// + +void quantize_row_iq4_xs_r4_ref(const float * x, block_iq4_xs_r4 * y, int64_t k) { + quantize_iq4_xs_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq4_xs_r4(const float * x, void * y, int64_t k) { + quantize_iq4_xs_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq4_xs(int nrows, int n_per_row, const block_iq4_xs * x, block_iq4_xs_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq4_xs * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].scales_l, 0, QK_K/16); + std::memset(y[ibl].scales_h, 0, QK_K/32); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + for (int ib = 0; ib < QK_K/32; ++ib) { + uint8_t sl = (x4[k][ibl].scales_l[ib/2] >> 4*(ib%2)) & 0xf; + uint8_t sh = (x4[k][ibl].scales_h >> 2*ib) & 3; + int i = 4*ib + k; + y[ibl].scales_l[i%16] |= (sl << 4*(i/16)); + y[ibl].scales_h[i%8 ] |= (sh << 2*(i/8)); + } + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row + y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 
from each row + y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row + y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq4_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ4_XS, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq4_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq4_xs(4, n_per_row, (const block_iq4_xs *)qtmp.data(), (block_iq4_xs_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 4*ib + k; + float dl = d * ((((x[ibl].scales_l[is%16] >> 4*(is/16)) & 0xf) | (((x[ibl].scales_h[is%8] >> 2*(is/8)) & 3) << 4)) - 32); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4]; + y4[k][QK_K*ibl+32*ib+i+16] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+24] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+16] >> 4]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+12] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+32] >> 4]; + 
y4[k][QK_K*ibl+32*ib+i+20] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+28] = dl * iq4k_values[x[ibl].qs[64*ib+4*k+i+48] >> 4]; + } + } + } + //dequantize_row_iq4_xs(x + ib, ytmp, QK_K); + //for (int k = 0; k < 4; ++k) { + // for (int l = 0; l < 16; ++l) { + // for (int i = 0; i < 4; ++i) { + // //y4[k][ib*kBlockSize + i + 16*(l%4) + 4*(l/4)] = ytmp[16*l + 4*k + i]; + // y4[k][ib*kBlockSize + i + 8*(l%8) + 4*(l/8)] = ytmp[16*l + 4*k + i]; + // } + // } + //} + } +} + +void vec_dot_iq4_xs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_XS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq4_ks_r4 +// + +void quantize_row_iq4_ks_r4_ref(const float * x, block_iq4_ks_r4 * y, int64_t k) { + quantize_iq4_ks_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq4_ks_r4(const float * x, void * y, int64_t k) { + quantize_iq4_ks_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq4_ks(int nrows, int n_per_row, const block_iq4_ks * x, block_iq4_ks_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + auto row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row); + int nblock = n_per_row/QK_K; + char * cy = (char *)y; + const char * cx = (const char *)x; + const block_iq4_ks * x4[4]; + for (int row = 0; row < nrows; row += 4) { + float * dptr = (float *)cy; + block_iq4_ks_r4 * y = (block_iq4_ks_r4 *)(dptr + 4); + for (int k = 0; k < 4; ++k) { + auto dk = (const float *)(cx + k*row_size); + dptr[k] = dk[0]; + x4[k] = (const block_iq4_ks *)(dk + 1); + } + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].scales[4*ib+k] = 
x4[k][ibl].scales[ib]; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row + y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row + y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row + y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row + } + } + } + } + cx += 4*row_size; + cy += 4*row_size; + } +} + +size_t quantize_iq4_ks_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ4_KS, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq4_ks(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq4_ks(4, n_per_row, (const block_iq4_ks *)qtmp.data(), (block_iq4_ks_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq4_ks_r4(const block_iq4_ks_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + const float * dptr = (const float *)x; + x = (const block_iq4_ks_r4 *)(dptr + 4); + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = dptr[k]; + for (int ib = 0; ib < QK_K/32; ++ib) { + float dl = d * ((x[ibl].scales[4*ib + k] & 254) - 127); + auto values = iq4k_values + ((x[ibl].scales[4*ib + k] & 1) << 4); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl * values[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl * 
values[x[ibl].qs[64*ib+4*k+i+ 0] >> 4]; + y4[k][QK_K*ibl+32*ib+i+16] = dl * values[x[ibl].qs[64*ib+4*k+i+16] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+24] = dl * values[x[ibl].qs[64*ib+4*k+i+16] >> 4]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl * values[x[ibl].qs[64*ib+4*k+i+32] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+12] = dl * values[x[ibl].qs[64*ib+4*k+i+32] >> 4]; + y4[k][QK_K*ibl+32*ib+i+20] = dl * values[x[ibl].qs[64*ib+4*k+i+48] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+28] = dl * values[x[ibl].qs[64*ib+4*k+i+48] >> 4]; + } + } + } + } +} + +void vec_dot_iq4_ks_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_KS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq2_bn_r4 +// +void quantize_row_iq2_bn_r4_ref(const float * x, block_iq2_bn * y, int64_t k) { + quantize_iq2_bn_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq2_bn_r4(const float * x, void * y, int64_t k) { + quantize_iq2_bn_r4(x, y, 4, k/4, nullptr); +} + +namespace { +void repack_iq2_bn(int nrows, int n_per_row, const char * x, char * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_IQ1BN == 0); + int nblock = n_per_row/QK_IQ1BN; + auto row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row); + const uint8_t * x4[4]; + for (int row = 0; row < nrows; row += 4) { + float * dr4 = (float *)(y + 4*row*row_size); + for (int k = 0; k < 4; ++k) { + const float * dptr = (const float *)(x + (row + k)*row_size); + dr4[k] = *dptr; + x4[k] = (const uint8_t *)(dptr + 1); + } + uint8_t * y4 = (uint8_t *)(dr4 + 4); + //std::memset(y4, 0, n_per_row); + for (int ib = 0; ib < nblock; ++ib) { + // 0...3 from rows 0...3 go to 1st 2 bits of 0...15 + // 16..19 from rows 0...3 go to 1st 2 bits of 16...31 + // 32..35 from rows 
0...3 go to 1st 2 bits of 32...47 + // 48..51 from rows 0...3 go to 1st 2 bits of 48...63 + // 4...7 from rows 0...3 go to 2nd 2 bits of 0...15 + // 20..23 from rows 0...3 go to 2nd 2 bits of 16...31 + // 36..39 from rows 0...3 go to 2nd 2 bits of 32...47 + // 52..55 from rows 0...3 go to 2nd 2 bits of 48...63 + // 8..11 from rows 0...3 go to 3rd 2 bits of 0...15 + // 24..27 from rows 0...3 go to 3rd 2 bits of 16...31 + // 40..43 from rows 0...3 go to 3rd 2 bits of 32...47 + // 56..59 from rows 0...3 go to 3rd 2 bits of 48...63 + // 12..15 from rows 0...3 go to 4th 2 bits of 0...15 + // 28..31 from rows 0...3 go to 4th 2 bits of 16...31 + // 44..47 from rows 0...3 go to 4th 2 bits of 32...47 + // 60..63 from rows 0...3 go to 4th 2 bits of 48...63 + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { + y4[64*ib + 4*k + i + 16*l] = (((x4[k][16*ib + i + 0] >> 2*l) & 3) << 0) | + (((x4[k][16*ib + i + 4] >> 2*l) & 3) << 2) | + (((x4[k][16*ib + i + 8] >> 2*l) & 3) << 4) | + (((x4[k][16*ib + i + 12] >> 2*l) & 3) << 6); + //y4[64*ib + 4*k + i + 0] |= (x4[k][16*ib + i] >> 0) & 3; + //y4[64*ib + 4*k + i + 16] |= (x4[k][16*ib + i] >> 2) & 3; + //y4[64*ib + 4*k + i + 32] |= (x4[k][16*ib + i] >> 4) & 3; + //y4[64*ib + 4*k + i + 48] |= (x4[k][16*ib + i] >> 6) & 3; + //y4[64*ib + 4*k + i + 0] |= ((x4[k][16*ib + i + 4] >> 0) & 3) << 2; + //y4[64*ib + 4*k + i + 16] |= ((x4[k][16*ib + i + 4] >> 2) & 3) << 2; + //y4[64*ib + 4*k + i + 32] |= ((x4[k][16*ib + i + 4] >> 4) & 3) << 2; + //y4[64*ib + 4*k + i + 48] |= ((x4[k][16*ib + i + 4] >> 6) & 3) << 2; + //y4[64*ib + 4*k + i + 0] |= ((x4[k][16*ib + i + 8] >> 0) & 3) << 4; + //y4[64*ib + 4*k + i + 16] |= ((x4[k][16*ib + i + 8] >> 2) & 3) << 4; + //y4[64*ib + 4*k + i + 32] |= ((x4[k][16*ib + i + 8] >> 4) & 3) << 4; + //y4[64*ib + 4*k + i + 48] |= ((x4[k][16*ib + i + 8] >> 6) & 3) << 4; + //y4[64*ib + 4*k + i + 0] |= ((x4[k][16*ib + i + 12] >> 0) & 3) << 6; + //y4[64*ib + 4*k + i + 16] |= 
((x4[k][16*ib + i + 12] >> 2) & 3) << 6; + //y4[64*ib + 4*k + i + 32] |= ((x4[k][16*ib + i + 12] >> 4) & 3) << 6; + //y4[64*ib + 4*k + i + 48] |= ((x4[k][16*ib + i + 12] >> 6) & 3) << 6; + } + } + } + } +} +} + +size_t quantize_iq2_bn_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_IQ1BN == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ2_BN, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq2_bn(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq2_bn(4, n_per_row, qtmp.data(), qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq2_bn_r4(const block_iq2_bn * x, float * y, int64_t k) { + static_assert(QK_IQ1BN == 64); + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + const float * d4 = (const float *)x; + const uint8_t * qx = (const uint8_t *)(d4 + 4); + int nblock = n_per_row/QK_IQ1BN; + for (int ib = 0; ib < nblock; ++ib) { + for (int k = 0; k < 4; ++k) { + for (int l = 0; l < 4; ++l) for (int i = 0; i < 4; ++i) { + uint8_t q = qx[4*k + i + 16*l]; + y4[k][64*ib + 16*l + i + 0] = d4[k] * (((q >> 0) & 3) - 1); + y4[k][64*ib + 16*l + i + 4] = d4[k] * (((q >> 2) & 3) - 1); + y4[k][64*ib + 16*l + i + 8] = d4[k] * (((q >> 4) & 3) - 1); + y4[k][64*ib + 16*l + i + 12] = d4[k] * (((q >> 6) & 3) - 1); + } + } + qx += 64; + } +} + +void vec_dot_iq2_bn_r4_q8_K64(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN_R4, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q4_k_r4 +// + +void
quantize_row_q4_k_r4_ref(const float * x, block_q4_k_r4 * y, int64_t k) { + quantize_q4_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q4_k_r4(const float * x, void * y, int64_t k) { + quantize_q4_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t& d, uint8_t& m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +inline void convert_q4_k(const block_q4_K& x, uint8_t * L, uint8_t * Ld, uint8_t * Lm) { + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + get_scale_min_k4(2*ib64+0, x.scales, Ld[2*ib64+0], Lm[2*ib64+0]); + get_scale_min_k4(2*ib64+1, x.scales, Ld[2*ib64+1], Lm[2*ib64+1]); + for (int j = 0; j < 32; ++j) { + L[64*ib64+j+ 0] = x.qs[32*ib64+j] & 0xf; + L[64*ib64+j+32] = x.qs[32*ib64+j] >> 4; + } + } +} +} + +static void repack_q4_k(int nrows, int n_per_row, const block_q4_K * x, block_q4_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q4_K * x4[4]; + uint8_t L[QK_K], Ld[QK_K/32], Lm[QK_K/32]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k+0] = x4[k][ibl].d; + y[ibl].d[k+4] = x4[k][ibl].dmin; + convert_q4_k(x4[k][ibl], L, Ld, Lm); + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].scales_l[4*ib+k] = (Ld[ib] & 0xf) | ((Lm[ib] & 0xf) << 4); + uint8_t h = (Ld[ib] >> 4) | ((Lm[ib] >> 4) << 2); + y[ibl].scales_h[(4*ib+k)%16] |= (h << 4*((4*ib+k)/16)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = L[32*ib+i+ 0] | (L[32*ib+i+ 8] << 4); + y[ibl].qs[64*ib+4*k+i+16] = L[32*ib+i+16] | (L[32*ib+i+24] << 4); + y[ibl].qs[64*ib+4*k+i+32] = L[32*ib+i+ 4] | (L[32*ib+i+12] << 4); + 
y[ibl].qs[64*ib+4*k+i+48] = L[32*ib+i+20] | (L[32*ib+i+28] << 4); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q4_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_q4_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_q4_k(4, n_per_row, (const block_q4_K *)qtmp.data(), (block_q4_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_q4_k_r4(const block_q4_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k+0]); + const float m = GGML_FP16_TO_FP32(x[ibl].d[k+4]); + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 4*ib + k; + float dl = d * ((x[ibl].scales_l[is] & 0xf) | (((x[ibl].scales_h[is%16] >> 4*(is/16)) & 0x03) << 4)); + float ml = m * ((x[ibl].scales_l[is] >> 4) | (((x[ibl].scales_h[is%16] >> 4*(is/16)) & 0x0c) << 2)); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl * (x[ibl].qs[64*ib+4*k+i+ 0] & 0xf) - ml; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl * (x[ibl].qs[64*ib+4*k+i+ 0] >> 4) - ml; + y4[k][QK_K*ibl+32*ib+i+16] = dl * (x[ibl].qs[64*ib+4*k+i+16] & 0xf) - ml; + y4[k][QK_K*ibl+32*ib+i+24] = dl * (x[ibl].qs[64*ib+4*k+i+16] >> 4) - ml; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl * (x[ibl].qs[64*ib+4*k+i+32] & 0xf) - ml; + y4[k][QK_K*ibl+32*ib+i+12] = dl * (x[ibl].qs[64*ib+4*k+i+32] >> 4) - ml; + y4[k][QK_K*ibl+32*ib+i+20] = dl * (x[ibl].qs[64*ib+4*k+i+48] & 0xf) - ml; + y4[k][QK_K*ibl+32*ib+i+28] = dl * (x[ibl].qs[64*ib+4*k+i+48] >> 4) - ml; + } + } + }
+ } +} + +void vec_dot_q4_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q4_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q6_k_r4 +// + +void quantize_row_q6_k_r4_ref(const float * x, block_q6_k_r4 * y, int64_t k) { + quantize_q6_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q6_k_r4(const float * x, void * y, int64_t k) { + quantize_q6_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_q6_k(const block_q6_K& x, uint8_t * L) { + const uint8_t * ql = x.ql; + const uint8_t * qh = x.qh; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + L[n + l + 0] = (ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4); + L[n + l + 32] = (ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4); + L[n + l + 64] = (ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4); + L[n + l + 96] = (ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4); + } + ql += 64; + qh += 32; + } +} +} + +static void repack_q6_k(int nrows, int n_per_row, const block_q6_K * x, block_q6_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q6_K * x4[4]; + uint8_t L[QK_K]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + convert_q6_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].scales[8*ib+k+0] = x4[k][ibl].scales[2*ib+0]; + y[ibl].scales[8*ib+k+4] = x4[k][ibl].scales[2*ib+1]; + for (int i = 0; i < 4; ++i) { + y[ibl].ql[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4); + y[ibl].ql[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | 
((L[32*ib+i+24] & 0xf) << 4); + y[ibl].ql[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 0xf) | ((L[32*ib+i+12] & 0xf) << 4); + y[ibl].ql[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4); + y[ibl].qh[32*ib+4*k+i+ 0] = (L[32*ib+i+ 0] >> 4) | ((L[32*ib+i+ 8] >> 4) << 2) | ((L[32*ib+i+ 4] >> 4) << 4) | ((L[32*ib+i+12] >> 4) << 6); + y[ibl].qh[32*ib+4*k+i+16] = (L[32*ib+i+16] >> 4) | ((L[32*ib+i+24] >> 4) << 2) | ((L[32*ib+i+20] >> 4) << 4) | ((L[32*ib+i+28] >> 4) << 6); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q6_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_q6_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_q6_k(4, n_per_row, (const block_q6_K *)qtmp.data(), (block_q6_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_q6_k_r4(const block_q6_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + auto ql = x[ibl].ql; + auto qh = x[ibl].qh; + for (int ib = 0; ib < QK_K/32; ++ib) { + float dl1 = d * x[ibl].scales[8*ib+k+0]; + float dl2 = d * x[ibl].scales[8*ib+k+4]; + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * (((ql[4*k+i+ 0] & 0xf) | ((qh[4*k+i+ 0] << 4) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * (((ql[4*k+i+ 0] >> 4) | ((qh[4*k+i+ 0] << 2) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * (((ql[4*k+i+16] & 0xf) | ((qh[4*k+i+16] << 4) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * (((ql[4*k+i+16]
>> 4) | ((qh[4*k+i+16] << 2) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * (((ql[4*k+i+32] & 0xf) | ((qh[4*k+i+ 0] >> 0) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * (((ql[4*k+i+32] >> 4) | ((qh[4*k+i+ 0] >> 2) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * (((ql[4*k+i+48] & 0xf) | ((qh[4*k+i+16] >> 0) & 0x30)) - 32); + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * (((ql[4*k+i+48] >> 4) | ((qh[4*k+i+16] >> 2) & 0x30)) - 32); + } + ql += 64; + qh += 32; + } + } + } +} + +void vec_dot_q6_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q6_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + + +// +// ========================================= q5_k_r4 +// + +void quantize_row_q5_k_r4_ref(const float * x, block_q5_k_r4 * y, int64_t k) { + quantize_q5_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q5_k_r4(const float * x, void * y, int64_t k) { + quantize_q5_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_q5_k(const block_q5_K& x, uint8_t * L, uint8_t * Ld, uint8_t * Lm) { + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + get_scale_min_k4(2*ib64+0, x.scales, Ld[2*ib64+0], Lm[2*ib64+0]); + get_scale_min_k4(2*ib64+1, x.scales, Ld[2*ib64+1], Lm[2*ib64+1]); + for (int j = 0; j < 32; ++j) { + L[64*ib64+j+ 0] = (x.qs[32*ib64+j] & 0xf) | (((x.qh[j] >> (2*ib64+0)) & 1) << 4); + L[64*ib64+j+32] = (x.qs[32*ib64+j] >> 4) | (((x.qh[j] >> (2*ib64+1)) & 1) << 4); + } + } +} +} + +static void repack_q5_k(int nrows, int n_per_row, const block_q5_K * x, block_q5_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q5_K * x4[4]; + uint8_t L[QK_K], Ld[QK_K/32], Lm[QK_K/32]; + for (int row = 0; row < nrows; row += 4) { + for (int 
k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k+0] = x4[k][ibl].d; + y[ibl].d[k+4] = x4[k][ibl].dmin; + convert_q5_k(x4[k][ibl], L, Ld, Lm); + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].scales_l[4*ib+k] = (Ld[ib] & 0xf) | ((Lm[ib] & 0xf) << 4); + uint8_t h = (Ld[ib] >> 4) | ((Lm[ib] >> 4) << 2); + y[ibl].scales_h[(4*ib+k)%16] |= (h << 4*((4*ib+k)/16)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4); + y[ibl].qs[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | ((L[32*ib+i+24] & 0xf) << 4); + y[ibl].qs[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 0xf) | ((L[32*ib+i+12] & 0xf) << 4); + y[ibl].qs[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4); + y[ibl].qh[16*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] >> 4) << 0) | ((L[32*ib+i+ 8] >> 4) << 1) | ((L[32*ib+i+ 4] >> 4) << 2) | ((L[32*ib+i+12] >> 4) << 3) | + ((L[32*ib+i+16] >> 4) << 4) | ((L[32*ib+i+24] >> 4) << 5) | ((L[32*ib+i+20] >> 4) << 6) | ((L[32*ib+i+28] >> 4) << 7); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q5_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_q5_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_q5_k(4, n_per_row, (const block_q5_K *)qtmp.data(), (block_q5_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_q5_k_r4(const block_q5_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock =
n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k+0]); + const float m = GGML_FP16_TO_FP32(x[ibl].d[k+4]); + auto ql = x[ibl].qs; + auto qh = x[ibl].qh; + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 4*ib + k; + float dl = d * ((x[ibl].scales_l[is] & 0xf) | (((x[ibl].scales_h[is%16] >> 4*(is/16)) & 0x03) << 4)); + float ml = m * ((x[ibl].scales_l[is] >> 4) | (((x[ibl].scales_h[is%16] >> 4*(is/16)) & 0x0c) << 2)); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl * ((ql[4*k+i+ 0] & 0xf) | ((qh[4*k+i] << 4) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl * ((ql[4*k+i+ 0] >> 4) | ((qh[4*k+i] << 3) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+16] = dl * ((ql[4*k+i+16] & 0xf) | ((qh[4*k+i] >> 0) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+24] = dl * ((ql[4*k+i+16] >> 4) | ((qh[4*k+i] >> 1) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl * ((ql[4*k+i+32] & 0xf) | ((qh[4*k+i] << 2) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+12] = dl * ((ql[4*k+i+32] >> 4) | ((qh[4*k+i] << 1) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+20] = dl * ((ql[4*k+i+48] & 0xf) | ((qh[4*k+i] >> 2) & 0x10)) - ml; + y4[k][QK_K*ibl+32*ib+i+28] = dl * ((ql[4*k+i+48] >> 4) | ((qh[4*k+i] >> 3) & 0x10)) - ml; + } + ql += 64; + qh += 16; + } + } + } +} + +void vec_dot_q5_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q5_K_R4, vx, 0, GGML_TYPE_Q8_K32, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q3_k_r4 +// + +void quantize_row_q3_k_r4_ref(const float * x, block_q3_k_r4 * y, int64_t k) { + quantize_q3_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q3_k_r4(const float * x, void * y, int64_t k) { + quantize_q3_k_r4(x, y, 
4, k/4, nullptr); +} + +namespace { +inline void convert_q3_k(const block_q3_K& x, uint8_t * L, uint8_t * Ld) { + constexpr uint32_t kmask1 = 0x03030303; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + uint32_t aux[4]; + memcpy(aux, x.scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + std::memcpy(Ld, aux, 16); + + const uint8_t * q = x.qs; + const uint8_t * hm = x.hmask; + uint8_t m = 1; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + for (int l = 0; l < 32; ++l) { + *L++ = ((q[l] >> shift) & 3) + ((hm[l] & m) ? 4 : 0); + } + shift += 2; + m <<= 1; + } + q += 32; + } +} +} + +static void repack_q3_k(int nrows, int n_per_row, const block_q3_K * x, block_q3_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q3_K * x4[4]; + uint8_t L[QK_K], Ld[QK_K/16]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + convert_q3_k(x4[k][ibl], L, Ld); + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib+k; + y[ibl].scales_l[is%32] |= (Ld[2*ib+0] & 0xf) << 4*(is/32); + y[ibl].scales_h[is%16] |= (Ld[2*ib+0] >> 4) << 2*(is/16); + is += 4; + y[ibl].scales_l[is%32] |= (Ld[2*ib+1] & 0xf) << 4*(is/32); + y[ibl].scales_h[is%16] |= (Ld[2*ib+1] >> 4) << 2*(is/16); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[32*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] & 0x3) << 0) | ((L[32*ib+i+ 4] & 0x3) << 2) | ((L[32*ib+i+ 8] & 0x3) << 4) | ((L[32*ib+i+12] & 0x3) << 6); + y[ibl].qs[32*ib+4*k+i+16] = ((L[32*ib+i+16] & 0x3) << 0) 
| ((L[32*ib+i+20] & 0x3) << 2) | ((L[32*ib+i+24] & 0x3) << 4) | ((L[32*ib+i+28] & 0x3) << 6); + y[ibl].qh[16*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] >> 2) << 0) | ((L[32*ib+i+ 4] >> 2) << 1) | ((L[32*ib+i+ 8] >> 2) << 2) | ((L[32*ib+i+12] >> 2) << 3) + | ((L[32*ib+i+16] >> 2) << 4) | ((L[32*ib+i+20] >> 2) << 5) | ((L[32*ib+i+24] >> 2) << 6) | ((L[32*ib+i+28] >> 2) << 7); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q3_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_q3_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_q3_k(4, n_per_row, (const block_q3_K *)qtmp.data(), (block_q3_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_q3_k_r4(const block_q3_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + auto ql = x[ibl].qs; + auto qh = x[ibl].qh; + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib + k; + float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 0x03) << 4)) - 32); + is += 4; + float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 0x03) << 4)) - 32); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * ((((ql[4*k+i+ 0] >> 0) & 3) | ((qh[4*k+i] << 2) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * ((((ql[4*k+i+ 0] >> 2) & 3) | ((qh[4*k+i] << 1) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * ((((ql[4*k+i+ 0] >> 4) & 3) |
((qh[4*k+i] << 0) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * ((((ql[4*k+i+ 0] >> 6) & 3) | ((qh[4*k+i] >> 1) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * ((((ql[4*k+i+16] >> 0) & 3) | ((qh[4*k+i] >> 2) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * ((((ql[4*k+i+16] >> 2) & 3) | ((qh[4*k+i] >> 3) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * ((((ql[4*k+i+16] >> 4) & 3) | ((qh[4*k+i] >> 4) & 4)) - 4); + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * ((((ql[4*k+i+16] >> 6) & 3) | ((qh[4*k+i] >> 5) & 4)) - 4); + } + ql += 32; + qh += 16; + } + } + } +} + +void vec_dot_q3_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q3_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q2_k_r4 +// + +void quantize_row_q2_k_r4_ref(const float * x, block_q2_k_r4 * y, int64_t k) { + quantize_q2_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_q2_k_r4(const float * x, void * y, int64_t k) { + quantize_q2_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_q2_k(const block_q2_K& x, uint8_t * L) { + + const uint8_t * qs = x.qs; + for (int n = 0; n < QK_K; n += 128) { + for (int j = 0; j < 32; ++j) { + L[n + j + 0] = (qs[j] >> 0) & 0x3; + L[n + j + 32] = (qs[j] >> 2) & 0x3; + L[n + j + 64] = (qs[j] >> 4) & 0x3; + L[n + j + 96] = (qs[j] >> 6) & 0x3; + } + qs += 32; + } +} +} + +static void repack_q2_k(int nrows, int n_per_row, const block_q2_K * x, block_q2_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q2_K * x4[4]; + uint8_t L[QK_K]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int
k = 0; k < 4; ++k) { + y[ibl].d[k+0] = x4[k][ibl].d; + y[ibl].d[k+4] = x4[k][ibl].dmin; + for (int ib = 0; ib < QK_K/16; ++ib) { + y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib]; + } + convert_q2_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int i = 0; i < 4; ++i) { + y[ibl].qs[32*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] & 0x3) << 0) | ((L[32*ib+i+ 4] & 0x3) << 2) | ((L[32*ib+i+ 8] & 0x3) << 4) | ((L[32*ib+i+12] & 0x3) << 6); + y[ibl].qs[32*ib+4*k+i+16] = ((L[32*ib+i+16] & 0x3) << 0) | ((L[32*ib+i+20] & 0x3) << 2) | ((L[32*ib+i+24] & 0x3) << 4) | ((L[32*ib+i+28] & 0x3) << 6); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_q2_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_q2_K(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_q2_k(4, n_per_row, (const block_q2_K *)qtmp.data(), (block_q2_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_q2_k_r4(const block_q2_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k+0]); + const float m = GGML_FP16_TO_FP32(x[ibl].d[k+4]); + auto ql = x[ibl].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + float dl1 = d * (x[ibl].scales[8*ib + k + 0] & 0xf); + float ml1 = m * (x[ibl].scales[8*ib + k + 0] >> 4); + float dl2 = d * (x[ibl].scales[8*ib + k + 4] & 0xf); + float ml2 = m * (x[ibl].scales[8*ib + k + 4] >> 4); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * ((ql[4*k+i+ 0] >> 0) & 3) - ml1; +
y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * ((ql[4*k+i+ 0] >> 2) & 3) - ml1; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * ((ql[4*k+i+ 0] >> 4) & 3) - ml1; + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * ((ql[4*k+i+ 0] >> 6) & 3) - ml1; + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * ((ql[4*k+i+16] >> 0) & 3) - ml2; + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * ((ql[4*k+i+16] >> 2) & 3) - ml2; + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * ((ql[4*k+i+16] >> 4) & 3) - ml2; + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * ((ql[4*k+i+16] >> 6) & 3) - ml2; + } + ql += 32; + } + } + } +} + +void vec_dot_q2_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q2_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq4_k_r4 +// + +void quantize_row_iq4_k_r4_ref(const float * x, block_iq4_k_r4 * y, int64_t k) { + quantize_iq4_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq4_k_r4(const float * x, void * y, int64_t k) { + quantize_iq4_k_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq4_k(int nrows, int n_per_row, const block_iq4_k * x, block_iq4_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq4_k * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].extra, 0, 8); + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + auto extra = x4[k][ibl].extra; + for (int ib = 0; ib < QK_K/32; ++ib) { + if (extra & 1) y[ibl].extra[k+0] |= (1 << ib); + if (extra & 2) y[ibl].extra[k+4] |= (1 << ib); + extra >>= 2; + uint8_t sl1 = x4[k][ibl].scales_l[ib] & 0xf; 
+ uint8_t sl2 = x4[k][ibl].scales_l[ib] >> 4; + uint8_t sh = x4[k][ibl].scales_h[ib/2] >> 4*(ib%2); + uint8_t sh1 = (sh >> 0) & 3; + uint8_t sh2 = (sh >> 2) & 3; + int i = 8*ib + k; + y[ibl].scales_l[i%32] |= (sl1 << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh1 << 2*(i/16)); + i += 4; + y[ibl].scales_l[i%32] |= (sl2 << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh2 << 2*(i/16)); + } + } + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (x4[k][ibl].qs[16*ib+i+0] & 0xf) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0x0f) << 4); // 0....3 + 8...11 from each row + y[ibl].qs[64*ib+4*k+i+16] = (x4[k][ibl].qs[16*ib+i+0] >> 4) | ((x4[k][ibl].qs[16*ib+i+ 8] & 0xf0)); // 16...19 + 24...27 from each row + y[ibl].qs[64*ib+4*k+i+32] = (x4[k][ibl].qs[16*ib+i+4] & 0xf) | ((x4[k][ibl].qs[16*ib+i+12] & 0x0f) << 4); // 4....7 + 12...15 from each row + y[ibl].qs[64*ib+4*k+i+48] = (x4[k][ibl].qs[16*ib+i+4] >> 4) | ((x4[k][ibl].qs[16*ib+i+12] & 0xf0)); // 20...23 + 28...31 from each row + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq4_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ4_K, n_per_row); + std::vector<char> qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq4_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq4_k(4, n_per_row, (const block_iq4_k *)qtmp.data(), (block_iq4_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq4_k_r4(const block_iq4_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d =
GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib + k; + float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + auto values1 = iq4k_values + (x[ibl].extra[k+0] & (1 << ib) ? 16 : 0); + auto values2 = iq4k_values + (x[ibl].extra[k+4] & (1 << ib) ? 16 : 0); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * values1[x[ibl].qs[64*ib+4*k+i+ 0] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * values1[x[ibl].qs[64*ib+4*k+i+ 0] >> 4]; + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * values2[x[ibl].qs[64*ib+4*k+i+16] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * values2[x[ibl].qs[64*ib+4*k+i+16] >> 4]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * values1[x[ibl].qs[64*ib+4*k+i+32] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * values1[x[ibl].qs[64*ib+4*k+i+32] >> 4]; + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * values2[x[ibl].qs[64*ib+4*k+i+48] & 0xf]; + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * values2[x[ibl].qs[64*ib+4*k+i+48] >> 4]; + } + } + } + } +} + +void vec_dot_iq4_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ4_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq5_k_r4 +// + +void quantize_row_iq5_k_r4_ref(const float * x, block_iq5_k_r4 * y, int64_t k) { + quantize_iq5_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq5_k_r4(const float * x, void * y, int64_t k) { + quantize_iq5_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_iq5_k(const block_iq5_k& x, uint8_t * L) { + const uint8_t * qs = x.qs; + 
const uint8_t * qh = x.qh; + int shift = 0; + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + for (int j = 0; j < 16; ++j) { + L[j+ 0] = (qs[j+ 0] & 0xf) | (((qh[j+ 0] >> shift) & 1) << 4); + L[j+16] = (qs[j+16] & 0xf) | (((qh[j+16] >> shift) & 1) << 4); + L[j+32] = (qs[j+ 0] >> 4) | (((qh[j+ 0] >> shift) & 2) << 3); + L[j+48] = (qs[j+16] >> 4) | (((qh[j+16] >> shift) & 2) << 3); + } + L += 64; + qs += 32; + shift += 2; + if (shift == 8) { qh += 32; shift = 0; } + } +} +} + +static void repack_iq5_k(int nrows, int n_per_row, const block_iq5_k * x, block_iq5_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq5_k * x4[4]; + uint8_t L[QK_K]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].extra, 0, 8); + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/16); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + auto extra = x4[k][ibl].extra; + convert_iq5_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + if (extra & 1) y[ibl].extra[k+0] |= (1 << ib); + if (extra & 2) y[ibl].extra[k+4] |= (1 << ib); + extra >>= 2; + uint8_t sl1 = x4[k][ibl].scales_l[ib] & 0xf; + uint8_t sl2 = x4[k][ibl].scales_l[ib] >> 4; + uint8_t sh = x4[k][ibl].scales_h[ib/2] >> 4*(ib%2); + uint8_t sh1 = (sh >> 0) & 3; + uint8_t sh2 = (sh >> 2) & 3; + int i = 8*ib + k; + y[ibl].scales_l[i%32] |= (sl1 << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh1 << 2*(i/16)); + i += 4; + y[ibl].scales_l[i%32] |= (sl2 << 4*(i/32)); + y[ibl].scales_h[i%16] |= (sh2 << 2*(i/16)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[64*ib+4*k+i+ 0] = (L[32*ib+i+ 0] & 0xf) | ((L[32*ib+i+ 8] & 0xf) << 4); // 0....3 + 8...11 from each row + y[ibl].qs[64*ib+4*k+i+16] = (L[32*ib+i+16] & 0xf) | ((L[32*ib+i+24] & 0xf) << 4); // 16...19 + 24...27 from each row + y[ibl].qs[64*ib+4*k+i+32] = (L[32*ib+i+ 4] & 
0xf) | ((L[32*ib+i+12] & 0xf) << 4); // 4....7 + 12...15 from each row + y[ibl].qs[64*ib+4*k+i+48] = (L[32*ib+i+20] & 0xf) | ((L[32*ib+i+28] & 0xf) << 4); // 20...23 + 28...31 from each row + y[ibl].qh[16*ib+4*k+i ] = ((L[32*ib+i+ 0] >> 4) << 0) | ((L[32*ib+i+ 8] >> 4) << 1) | ((L[32*ib+i+16] >> 4) << 2) | ((L[32*ib+i+24] >> 4) << 3) + | ((L[32*ib+i+ 4] >> 4) << 4) | ((L[32*ib+i+12] >> 4) << 5) | ((L[32*ib+i+20] >> 4) << 6) | ((L[32*ib+i+28] >> 4) << 7); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq5_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ5_K, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq5_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq5_k(4, n_per_row, (const block_iq5_k *)qtmp.data(), (block_iq5_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq5_k_r4(const block_iq5_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib + k; + float dl1 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + is += 4; + float dl2 = d * ((((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) | (((x[ibl].scales_h[is%16] >> 2*(is/16)) & 3) << 4)) - 32); + auto values1 = iq5nl_values + (x[ibl].extra[k+0] & (1 << ib) ? 32 : 0); + auto values2 = iq5nl_values + (x[ibl].extra[k+4] & (1 << ib) ? 
32 : 0); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+ 0] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 0) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+ 0] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 1) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+16] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 2) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+16] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 3) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+32] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 4) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * values1[(x[ibl].qs[64*ib+4*k+i+32] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 5) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+48] & 0xf) | (((x[ibl].qh[16*ib+4*k+i] >> 6) & 1) << 4)]; + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * values2[(x[ibl].qs[64*ib+4*k+i+48] >> 4) | (((x[ibl].qh[16*ib+4*k+i] >> 7) & 1) << 4)]; + } + } + } + } +} + +void vec_dot_iq5_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ5_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= q8_k_r8 +// + +void quantize_row_q8_k_r8_ref(const float * x, block_q8_k_r8 * y, int64_t k) { + quantize_q8_k_r8(x, (void *)y, 8, k/8, nullptr); +} + +void quantize_row_q8_k_r8(const float * x, void * y, int64_t k) { + quantize_q8_k_r8(x, y, 8, k/8, nullptr); +} + +static void repack_q8_k(int nrows, int n_per_row, const block_q8_K * x, block_q8_k_r8 * y) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_q8_K * x8[8]; + for (int row = 0; 
row < nrows; row += 8) { + for (int k = 0; k < 8; ++k) x8[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 8; ++k) { + y[ibl].d[k] = GGML_FP32_TO_FP16(x8[k][ibl].d); + for (int ib = 0; ib < QK_K/4; ++ib) { + for (int i = 0; i < 4; ++i) y[ibl].qs[32*ib + 4*k + i] = x8[k][ibl].qs[4*ib+i]; + } + } + } + x += 8*nblock; + y += nblock; + } +} + +size_t quantize_q8_k_r8(const float * src, void * dst, int64_t nrows, int64_t n_per_row, [[maybe_unused]] const float * imatrix) { + GGML_ASSERT(nrows%8 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size_0 = ggml_row_size(GGML_TYPE_Q8_K, n_per_row); + auto row_size_1 = ggml_row_size(GGML_TYPE_Q8_K_R8, n_per_row); + std::vector qtmp(8*row_size_0); + for (int row = 0; row < nrows; row += 8) { + quantize_row_q8_K32(src, (void *)qtmp.data(), 8*n_per_row); + repack_q8_k(8, n_per_row, (const block_q8_K *)qtmp.data(), (block_q8_k_r8 *)qcur); + qcur += 8*row_size_1; + src += 8*n_per_row; + } + return nrows*row_size_1; +} + +void dequantize_row_q8_k_r8(const block_q8_k_r8 * x, float * y, int64_t k) { + auto n_per_row = k/8; + float * y8[8]; + for (int k = 0; k < 8; ++k) y8[k] = y + n_per_row*k; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 8; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/4; ++ib) { + for (int i = 0; i < 4; ++i) { + y8[k][QK_K*ibl+4*ib+i] = d * x[ibl].qs[32*ib+4*k+i]; + } + } + } + } +} + +void vec_dot_q8_k_r8_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_Q8_K_R8, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= bf16_r4 +// +namespace { +inline ggml_bf16_t 
to_bf16(const float& x) { + union { float f; uint32_t u; } helper; + helper.f = x; + return ggml_bf16_t{(uint16_t)(helper.u >> 16)}; +} +inline ggml_bf16_t to_bf16(const ggml_half& x) { return to_bf16(GGML_FP16_TO_FP32(x)); } +inline ggml_bf16_t to_bf16(const ggml_bf16_t& x) { return x; } +template +void repack_bf16(int nrows, int n_per_row, const T * x, ggml_bf16_t * y) { + GGML_ASSERT(nrows%16 == 0); + GGML_ASSERT(n_per_row%2 == 0); + for (int row = 0; row < nrows; row += 16) { + for (int k = 0; k < 16; ++k) { + auto x8 = x + k*n_per_row; + for (int ib = 0; ib < n_per_row/2; ++ib) { + y[32*ib + 2*k + 0] = to_bf16(x8[2*ib+0]); + y[32*ib + 2*k + 1] = to_bf16(x8[2*ib+1]); + } + } + x += 16*n_per_row; + y += 16*n_per_row; + } +} +} + +void repack_f32_bf16_r16(const void * src, void * dst, int64_t nrows, int64_t n_per_row) { + repack_bf16(nrows, n_per_row, (const float *)src, (ggml_bf16_t *)dst); +} + +void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row) { + repack_bf16(nrows, n_per_row, (const ggml_bf16_t *)src, (ggml_bf16_t *)dst); +} + +// +// ========================================= iq3_k_r4 +// + +void quantize_row_iq3_k_r4_ref(const float * x, block_iq3_k_r4 * y, int64_t k) { + quantize_iq3_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq3_k_r4(const float * x, void * y, int64_t k) { + quantize_iq3_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_iq3_k(const block_iq3_k& x, uint8_t * L) { + const uint8_t * qs = x.qs; + const uint8_t * qh = x.qh; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int shift_l = 2*(ib32%4); + int shift_h = ib32%8; + for (int j = 0; j < 16; ++j) { + L[j+ 0] = ((qs[j+ 0] >> shift_l) & 3) | (((qh[j+ 0] >> shift_h) & 1) << 2); + L[j+16] = ((qs[j+16] >> shift_l) & 3) | (((qh[j+16] >> shift_h) & 1) << 2); + } + L += 32; + if (shift_l == 6) qs += 32; + } +} +} + +static void repack_iq3_k(int nrows, int n_per_row, const block_iq3_k * x, 
block_iq3_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq3_k * x4[4]; + uint8_t L[QK_K]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].extra, 0, 8); + std::memset(y[ibl].scales_l, 0, QK_K/8); + std::memset(y[ibl].scales_h, 0, QK_K/32); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + auto extra = x4[k][ibl].extra; + uint16_t sh = x4[k][ibl].scales_h; + convert_iq3_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + if (extra & 1) y[ibl].extra[k+0] |= (1 << ib); + if (extra & 2) y[ibl].extra[k+4] |= (1 << ib); + extra >>= 2; + uint8_t sl1 = x4[k][ibl].scales_l[ib] & 0xf; + uint8_t sl2 = x4[k][ibl].scales_l[ib] >> 4; + uint8_t sh1 = (sh >> 0) & 1; + uint8_t sh2 = (sh >> 1) & 1; + sh >>= 2; + int i = 8*ib + k; + y[ibl].scales_l[i%32] |= (sl1 << 4*(i/32)); + y[ibl].scales_h[i%8 ] |= (sh1 << (i/8)); + i += 4; + y[ibl].scales_l[i%32] |= (sl2 << 4*(i/32)); + y[ibl].scales_h[i%8 ] |= (sh2 << (i/8)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[32*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] & 0x3) << 0) | ((L[32*ib+i+ 4] & 0x3) << 2) | ((L[32*ib+i+ 8] & 0x3) << 4) | ((L[32*ib+i+12] & 0x3) << 6); + y[ibl].qs[32*ib+4*k+i+16] = ((L[32*ib+i+16] & 0x3) << 0) | ((L[32*ib+i+20] & 0x3) << 2) | ((L[32*ib+i+24] & 0x3) << 4) | ((L[32*ib+i+28] & 0x3) << 6); + y[ibl].qh[16*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] >> 2) << 0) | ((L[32*ib+i+ 4] >> 2) << 1) | ((L[32*ib+i+ 8] >> 2) << 2) | ((L[32*ib+i+12] >> 2) << 3) + | ((L[32*ib+i+16] >> 2) << 4) | ((L[32*ib+i+20] >> 2) << 5) | ((L[32*ib+i+24] >> 2) << 6) | ((L[32*ib+i+28] >> 2) << 7); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq3_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + 
auto row_size = ggml_row_size(GGML_TYPE_IQ3_K, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq3_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq3_k(4, n_per_row, (const block_iq3_k *)qtmp.data(), (block_iq3_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq3_k_r4(const block_iq3_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + auto ql = x[ibl].qs; + auto qh = x[ibl].qh; + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib + k; + float dl1 = d * (2*((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((x[ibl].scales_h[is%8] >> (is/8)) & 1 ? -1 : 1); + is += 4; + float dl2 = d * (2*((x[ibl].scales_l[is%32] >> 4*(is/32)) & 0xf) + 1) * ((x[ibl].scales_h[is%8] >> (is/8)) & 1 ? -1 : 1); + auto values1 = iq3nl_values + (x[ibl].extra[k+0] & (1 << ib) ? 8 : 0); + auto values2 = iq3nl_values + (x[ibl].extra[k+4] & (1 << ib) ? 
8 : 0); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * values1[((ql[4*k+i+ 0] >> 0) & 3) | ((qh[4*k+i] << 2) & 4)]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * values1[((ql[4*k+i+ 0] >> 2) & 3) | ((qh[4*k+i] << 1) & 4)]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * values1[((ql[4*k+i+ 0] >> 4) & 3) | ((qh[4*k+i] << 0) & 4)]; + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * values1[((ql[4*k+i+ 0] >> 6) & 3) | ((qh[4*k+i] >> 1) & 4)]; + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * values2[((ql[4*k+i+16] >> 0) & 3) | ((qh[4*k+i] >> 2) & 4)]; + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * values2[((ql[4*k+i+16] >> 2) & 3) | ((qh[4*k+i] >> 3) & 4)]; + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * values2[((ql[4*k+i+16] >> 4) & 3) | ((qh[4*k+i] >> 4) & 4)]; + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * values2[((ql[4*k+i+16] >> 6) & 3) | ((qh[4*k+i] >> 5) & 4)]; + } + ql += 32; + qh += 16; + } + } + } +} + +void vec_dot_iq3_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq2_k_r4 +// + +void quantize_row_iq2_k_r4_ref(const float * x, block_iq2_k_r4 * y, int64_t k) { + quantize_iq2_k_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq2_k_r4(const float * x, void * y, int64_t k) { + quantize_iq2_k_r4(x, y, 4, k/4, nullptr); +} + +namespace { +inline void convert_iq2_k(const block_iq2_k& x, uint8_t * L) { + const uint8_t * qs = x.qs; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int shift_l = 2*(ib32%4); + for (int j = 0; j < 16; ++j) { + L[j+ 0] = ((qs[j+ 0] >> shift_l) & 3); + L[j+16] = ((qs[j+16] >> shift_l) & 3); + } + L += 32; + if (shift_l == 6) qs += 32; + } +} +} + +static void repack_iq2_k(int nrows, int n_per_row, const block_iq2_k * x, 
block_iq2_k_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq2_k * x4[4]; + uint8_t L[QK_K]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].extra, 0, 8); + std::memset(y[ibl].scales, 0, QK_K/8); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + auto extra = x4[k][ibl].extra; + convert_iq2_k(x4[k][ibl], L); + for (int ib = 0; ib < QK_K/32; ++ib) { + if (extra & 1) y[ibl].extra[k+0] |= (1 << ib); + if (extra & 2) y[ibl].extra[k+4] |= (1 << ib); + extra >>= 2; + uint8_t sl1 = x4[k][ibl].scales[ib] & 0xf; + uint8_t sl2 = x4[k][ibl].scales[ib] >> 4; + int i = 8*ib + k; + y[ibl].scales[i%32] |= (sl1 << 4*(i/32)); + i += 4; + y[ibl].scales[i%32] |= (sl2 << 4*(i/32)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[32*ib+4*k+i+ 0] = ((L[32*ib+i+ 0] & 0x3) << 0) | ((L[32*ib+i+ 4] & 0x3) << 2) | ((L[32*ib+i+ 8] & 0x3) << 4) | ((L[32*ib+i+12] & 0x3) << 6); + y[ibl].qs[32*ib+4*k+i+16] = ((L[32*ib+i+16] & 0x3) << 0) | ((L[32*ib+i+20] & 0x3) << 2) | ((L[32*ib+i+24] & 0x3) << 4) | ((L[32*ib+i+28] & 0x3) << 6); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq2_k_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ2_K, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq2_k(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq2_k(4, n_per_row, (const block_iq2_k *)qtmp.data(), (block_iq2_k_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq2_k_r4(const block_iq2_k_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 
2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + auto ql = x[ibl].qs; + for (int ib = 0; ib < QK_K/32; ++ib) { + int is = 8*ib + k; + float dl1 = d * (((x[ibl].scales[is%32] >> 4*(is/32)) & 0xf) - 8); + is += 4; + float dl2 = d * (((x[ibl].scales[is%32] >> 4*(is/32)) & 0xf) - 8); + auto values1 = iq2nl_values + (x[ibl].extra[k+0] & (1 << ib) ? 4 : 0); + auto values2 = iq2nl_values + (x[ibl].extra[k+4] & (1 << ib) ? 4 : 0); + for (int i = 0; i < 4; ++i) { + y4[k][QK_K*ibl+32*ib+i+ 0] = dl1 * values1[(ql[4*k+i+ 0] >> 0) & 3]; + y4[k][QK_K*ibl+32*ib+i+ 4] = dl1 * values1[(ql[4*k+i+ 0] >> 2) & 3]; + y4[k][QK_K*ibl+32*ib+i+ 8] = dl1 * values1[(ql[4*k+i+ 0] >> 4) & 3]; + y4[k][QK_K*ibl+32*ib+i+12] = dl1 * values1[(ql[4*k+i+ 0] >> 6) & 3]; + y4[k][QK_K*ibl+32*ib+i+16] = dl2 * values2[(ql[4*k+i+16] >> 0) & 3]; + y4[k][QK_K*ibl+32*ib+i+20] = dl2 * values2[(ql[4*k+i+16] >> 2) & 3]; + y4[k][QK_K*ibl+32*ib+i+24] = dl2 * values2[(ql[4*k+i+16] >> 4) & 3]; + y4[k][QK_K*ibl+32*ib+i+28] = dl2 * values2[(ql[4*k+i+16] >> 6) & 3]; + } + ql += 32; + } + } + } +} + +void vec_dot_iq2_k_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_K_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +namespace { +struct Repack { + using repack_func = void (*) (int nrows, int n_per_row, const char * src, char * dst); + ggml_type new_type; + int num_rows; + repack_func repack; +}; +} + +namespace { +inline uint8_t scrambled_sign(uint8_t s) { + static const uint8_t k_table[128] = { + 0x00, 0x7f, 0x7e, 0x01, 0x7c, 0x03, 0x02, 0x7d, 0x78, 0x07, 0x06, 0x79, 0x04, 0x7b, 0x7a, 0x05, + 0x70, 0x0f, 0x0e, 0x71, 0x0c, 
0x73, 0x72, 0x0d, 0x08, 0x77, 0x76, 0x09, 0x74, 0x0b, 0x0a, 0x75, + 0x60, 0x1f, 0x1e, 0x61, 0x1c, 0x63, 0x62, 0x1d, 0x18, 0x67, 0x66, 0x19, 0x64, 0x1b, 0x1a, 0x65, + 0x10, 0x6f, 0x6e, 0x11, 0x6c, 0x13, 0x12, 0x6d, 0x68, 0x17, 0x16, 0x69, 0x14, 0x6b, 0x6a, 0x15, + 0x40, 0x3f, 0x3e, 0x41, 0x3c, 0x43, 0x42, 0x3d, 0x38, 0x47, 0x46, 0x39, 0x44, 0x3b, 0x3a, 0x45, + 0x30, 0x4f, 0x4e, 0x31, 0x4c, 0x33, 0x32, 0x4d, 0x48, 0x37, 0x36, 0x49, 0x34, 0x4b, 0x4a, 0x35, + 0x20, 0x5f, 0x5e, 0x21, 0x5c, 0x23, 0x22, 0x5d, 0x58, 0x27, 0x26, 0x59, 0x24, 0x5b, 0x5a, 0x25, + 0x50, 0x2f, 0x2e, 0x51, 0x2c, 0x53, 0x52, 0x2d, 0x28, 0x57, 0x56, 0x29, 0x54, 0x2b, 0x2a, 0x55, + }; + return k_table[s]; +} +} + +// +// ========================================= iq2_xxs_r4 +// + +void quantize_row_iq2_xxs_r4_ref(const float * x, block_iq2_xxs_r4 * y, int64_t k) { + quantize_iq2_xxs_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq2_xxs_r4(const float * x, void * y, int64_t k) { + quantize_iq2_xxs_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq2_xxs(int nrows, int n_per_row, const block_iq2_xxs * x, block_iq2_xxs_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq2_xxs * x4[4]; + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto ysas = (uint32_t *)y[ibl].sas; + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + for (int ib = 0; ib < QK_K/32; ++ib) { + std::memcpy(aux32, x4[k][ibl].qs + 4*ib, 2*sizeof(uint32_t)); + for (int i = 0; i < 4; ++i) { + y[ibl].qs[16*ib+4*k+i] = aux8[i]; + } + uint8_t scale = aux32[1] >> 28; + uint8_t s1 = (scrambled_sign((aux32[1] >> 0) & 127) << 1) | ((scale >> 0) & 1); + uint8_t s2 = (scrambled_sign((aux32[1] >> 7) & 127) << 1) | ((scale >> 1) & 1); + uint8_t s3 = (scrambled_sign((aux32[1] >> 14) & 127) << 1) | 
((scale >> 2) & 1); + uint8_t s4 = (scrambled_sign((aux32[1] >> 21) & 127) << 1) | ((scale >> 3) & 1); + aux32[1] = uint32_t(s1) | (uint32_t(s2) << 8) | (uint32_t(s3) << 16) | (uint32_t(s4) << 24); + ysas[4*ib+k] = aux32[1]; + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq2_xxs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ2_XXS, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq2_xxs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq2_xxs(4, n_per_row, (const block_iq2_xxs *)qtmp.data(), (block_iq2_xxs_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq2_xxs_r4(const block_iq2_xxs_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + uint32_t s32; + const uint8_t * s8 = (const uint8_t *)&s32; + for (int ibl = 0; ibl < nblock; ++ibl) { + const uint32_t * sas = (const uint32_t *)x[ibl].sas; + for (int k = 0; k < 4; ++k) { + const float d = 0.125f*GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + uint32_t aux32 = sas[4*ib+k]; + s32 = aux32 & 0x01010101; + uint8_t scale = s8[0] | (s8[1] << 1) | (s8[2] << 2) | (s8[3] << 3); + float dl = d*(2*scale+1); + aux32 &= 0xfefefefe; + aux32 ^= (aux32 >> 1); + for (int i = 0; i < 4; ++i) { + auto val = (const int8_t *)(iq2xxs_grid + x[ibl].qs[16*ib+4*k+i]); + for (int j = 0; j < 8; ++j) y4[k][QK_K*ibl+32*ib+8*i+j] = dl * val[j] * (aux32 & (1 << j) ? 
-1 : 1); + aux32 >>= 8; + } + } + } + } +} + +void vec_dot_iq2_xxs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_XXS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq2_xs_r4 +// + +void quantize_row_iq2_xs_r4_ref(const float * x, block_iq2_xs_r4 * y, int64_t k) { + quantize_iq2_xs_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq2_xs_r4(const float * x, void * y, int64_t k) { + quantize_iq2_xs_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq2_xs(int nrows, int n_per_row, const block_iq2_xs * x, block_iq2_xs_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq2_xs * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int i = 0; i < 4; ++i) { + uint16_t v = x4[k][ibl].qs[4*ib+i]; + uint8_t s = v >> 9; + y[ibl].qs[16*ib+4*k+i] = (v & 511) | (scrambled_sign(s) << 9); + } + y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib]; + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq2_xs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ2_XS, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq2_xs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq2_xs(4, n_per_row, (const block_iq2_xs *)qtmp.data(), (block_iq2_xs_r4 *)qcur); + 
qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq2_xs_r4(const block_iq2_xs_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = 0.125f*GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + float dl1 = d * (2*(x[ibl].scales[4*ib+k] & 0xf) + 1); + float dl2 = d * (2*(x[ibl].scales[4*ib+k] >> 4) + 1); + for (int i = 0; i < 4; ++i) { + auto val = (const int8_t *)(iq2xs_grid + (x[ibl].qs[16*ib+4*k+i] & 511)); + auto signs = x[ibl].qs[16*ib+4*k+i] >> 9; + signs ^= (signs << 1); + float dl = i < 2 ? dl1 : dl2; + for (int j = 0; j < 8; ++j) y4[k][QK_K*ibl+32*ib+8*i+j] = dl * val[j] * (signs & (1 << j) ? -1 : 1); + } + } + } + } +} + +void vec_dot_iq2_xs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_XS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq2_s_r4 +// + +void quantize_row_iq2_s_r4_ref(const float * x, block_iq2_s_r4 * y, int64_t k) { + quantize_iq2_s_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq2_s_r4(const float * x, void * y, int64_t k) { + quantize_iq2_s_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq2_s(int nrows, int n_per_row, const block_iq2_s * x, block_iq2_s_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq2_s * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) 
{ + auto signs = x4[k][ibl].qs + QK_K/8; + y[ibl].d[k] = x4[k][ibl].d; + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].scales[4*ib+k] = x4[k][ibl].scales[ib]; + for (int i = 0; i < 4; ++i) { + y[ibl].qs[16*ib+4*k+i] = x4[k][ibl].qs[4*ib+i]; + y[ibl].signs[16*ib+4*k+i] = signs[4*ib+i]; + } + y[ibl].qh[4*ib+k] = x4[k][ibl].qh[ib]; + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq2_s_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ2_S, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq2_s(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq2_s(4, n_per_row, (const block_iq2_s *)qtmp.data(), (block_iq2_s_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq2_s_r4(const block_iq2_s_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = 0.125f*GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + float dl1 = d * (2*(x[ibl].scales[4*ib+k] & 0xf) + 1); + float dl2 = d * (2*(x[ibl].scales[4*ib+k] >> 4) + 1); + for (int i = 0; i < 4; ++i) { + auto val = (const int8_t *)(iq2s_grid + (x[ibl].qs[16*ib+4*k+i] | ((x[ibl].qh[4*ib+k] << (8 - 2*i)) & 0x300))); + auto signs = x[ibl].signs[16*ib+4*k+i]; + float dl = i < 2 ? dl1 : dl2; + for (int j = 0; j < 8; ++j) y4[k][QK_K*ibl+32*ib+8*i+j] = dl * val[j] * (signs & (1 << j) ? 
-1 : 1); + } + } + } + } +} + +void vec_dot_iq2_s_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_S_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq3_xxs_r4 +// + +void quantize_row_iq3_xxs_r4_ref(const float * x, block_iq3_xxs_r4 * y, int64_t k) { + quantize_iq3_xxs_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq3_xxs_r4(const float * x, void * y, int64_t k) { + quantize_iq3_xxs_r4(x, y, 4, k/4, nullptr); +} + +namespace { +} + +static void repack_iq3_xxs(int nrows, int n_per_row, const block_iq3_xxs * x, block_iq3_xxs_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq3_xxs * x4[4]; + uint32_t aux32; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + auto ysas = (uint32_t *)y[ibl].sas; + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + auto xsas = x4[k][ibl].qs + QK_K/4; + for (int ib = 0; ib < QK_K/32; ++ib) { + for (int i = 0; i < 8; ++i) { + y[ibl].qs[32*ib+8*k+i] = x4[k][ibl].qs[8*ib+i]; + } + std::memcpy(&aux32, xsas + 4*ib, 4); + uint8_t scale = aux32 >> 28; + uint8_t s1 = (scrambled_sign((aux32 >> 0) & 127) << 1) | ((scale >> 0) & 1); + uint8_t s2 = (scrambled_sign((aux32 >> 7) & 127) << 1) | ((scale >> 1) & 1); + uint8_t s3 = (scrambled_sign((aux32 >> 14) & 127) << 1) | ((scale >> 2) & 1); + uint8_t s4 = (scrambled_sign((aux32 >> 21) & 127) << 1) | ((scale >> 3) & 1); + aux32 = uint32_t(s1) | (uint32_t(s2) << 8) | (uint32_t(s3) << 16) | (uint32_t(s4) << 24); + ysas[4*ib+k] = aux32; + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t 
quantize_iq3_xxs_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ3_XXS, n_per_row); + std::vector qtmp(4*row_size); + for (int row = 0; row < nrows; row += 4) { + quantize_iq3_xxs(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq3_xxs(4, n_per_row, (const block_iq3_xxs *)qtmp.data(), (block_iq3_xxs_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq3_xxs_r4(const block_iq3_xxs_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + uint32_t s32; + const uint8_t * s8 = (const uint8_t *)&s32; + for (int ibl = 0; ibl < nblock; ++ibl) { + const uint32_t * sas = (const uint32_t *)x[ibl].sas; + for (int k = 0; k < 4; ++k) { + const float d = 0.25f*GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + uint32_t aux32 = sas[4*ib+k]; + s32 = aux32 & 0x01010101; + uint8_t scale = s8[0] | (s8[1] << 1) | (s8[2] << 2) | (s8[3] << 3); + float dl = d*(2*scale+1); + aux32 &= 0xfefefefe; + aux32 ^= (aux32 >> 1); + for (int i = 0; i < 8; ++i) { + auto val = (const int8_t *)(iq3xxs_grid + x[ibl].qs[32*ib+8*k+i]); + for (int j = 0; j < 4; ++j) y4[k][QK_K*ibl+32*ib+4*i+j] = dl * val[j] * (aux32 & (1 << j) ? 
-1 : 1); + aux32 >>= 4; + } + } + } + } +} + +void vec_dot_iq3_xxs_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_XXS_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +// +// ========================================= iq3_s_r4 +// + +void quantize_row_iq3_s_r4_ref(const float * x, block_iq3_s_r4 * y, int64_t k) { + quantize_iq3_s_r4(x, (void *)y, 4, k/4, nullptr); +} + +void quantize_row_iq3_s_r4(const float * x, void * y, int64_t k) { + quantize_iq3_s_r4(x, y, 4, k/4, nullptr); +} + +static void repack_iq3_s(int nrows, int n_per_row, const block_iq3_s * x, block_iq3_s_r4 * y) { + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + int nblock = n_per_row/QK_K; + const block_iq3_s * x4[4]; + for (int row = 0; row < nrows; row += 4) { + for (int k = 0; k < 4; ++k) x4[k] = x + nblock*k; + for (int ibl = 0; ibl < nblock; ++ibl) { + std::memset(y[ibl].scales, 0, QK_K/16); + std::memset(y[ibl].signs, 0, QK_K/2); + std::memset(y[ibl].qh, 0, QK_K/8); + for (int k = 0; k < 4; ++k) { + y[ibl].d[k] = x4[k][ibl].d; + for (int ib = 0; ib < QK_K/64; ++ib) { + int j = 8*ib + k; /* j and j+4 address the two 4-bit scales of source byte ib; each is OR-ed into nibble (idx/16) of y.scales[idx%16] */ + y[ibl].scales[(j+0)%16] |= ((x4[k][ibl].scales[ib] & 0xf) << 4*((j+0)/16)); + y[ibl].scales[(j+4)%16] |= ((x4[k][ibl].scales[ib] >> 4) << 4*((j+4)/16)); + } + for (int ib = 0; ib < QK_K/32; ++ib) { + y[ibl].qh[4*ib+k] = x4[k][ibl].qh[ib]; // leave it like this? 
+ for (int i = 0; i < 4; ++i) { + y[ibl].qs[32*ib+k+8*i+0] = x4[k][ibl].qs[8*ib+i+0]; + y[ibl].qs[32*ib+k+8*i+4] = x4[k][ibl].qs[8*ib+i+4]; + } + for (int i = 0; i < 4; ++i) { + y[ibl].signs[16*ib+4*k+i] = (((x4[k][ibl].signs[4*ib+0] >> i) & 1) << 0) | (((x4[k][ibl].signs[4*ib+0] >> (4+i)) & 1) << 1) | + (((x4[k][ibl].signs[4*ib+1] >> i) & 1) << 2) | (((x4[k][ibl].signs[4*ib+1] >> (4+i)) & 1) << 3) | + (((x4[k][ibl].signs[4*ib+2] >> i) & 1) << 4) | (((x4[k][ibl].signs[4*ib+2] >> (4+i)) & 1) << 5) | + (((x4[k][ibl].signs[4*ib+3] >> i) & 1) << 6) | (((x4[k][ibl].signs[4*ib+3] >> (4+i)) & 1) << 7); + } + } + } + } + x += 4*nblock; + y += nblock; + } +} + +size_t quantize_iq3_s_r4(const float * src, void * dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { /* Quantizes 4 rows at a time to plain IQ3_S in a scratch buffer, then repacks that group into the interleaved IQ3_S_R4 layout in dst; returns nrows*row_size. */ + GGML_ASSERT(nrows%4 == 0); + GGML_ASSERT(n_per_row%QK_K == 0); + char * qcur = (char *)dst; + auto row_size = ggml_row_size(GGML_TYPE_IQ3_S, n_per_row); + std::vector<char> qtmp(4*row_size); /* byte scratch for 4 rows in the non-interleaved layout; the element type must be spelled out, a count alone cannot deduce it */ + for (int row = 0; row < nrows; row += 4) { + quantize_iq3_s(src, (void *)qtmp.data(), 4, n_per_row, imatrix); + repack_iq3_s(4, n_per_row, (const block_iq3_s *)qtmp.data(), (block_iq3_s_r4 *)qcur); + qcur += 4*row_size; + src += 4*n_per_row; + } + return nrows*row_size; +} + +void dequantize_row_iq3_s_r4(const block_iq3_s_r4 * x, float * y, int64_t k) { + auto n_per_row = k/4; + float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row}; + int nblock = n_per_row/QK_K; + for (int ibl = 0; ibl < nblock; ++ibl) { + for (int k = 0; k < 4; ++k) { + const float d = GGML_FP16_TO_FP32(x[ibl].d[k]); + for (int ib = 0; ib < QK_K/32; ++ib) { + int l = 4*ib + k; + float dl = d * (1 + 2*((x[ibl].scales[l%16] >> 4*(l/16)) & 0xf)); + for (int i = 0; i < 4; ++i) { + auto grid1 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+0] + ((x[ibl].qh[4*ib+k] << (8-i)) & 0x100)); + auto grid2 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+4] + ((x[ibl].qh[4*ib+k] << (4-i)) & 0x100)); + for (int j = 0; j < 4; ++j) { + 
y4[k][QK_K*ibl+32*ib+4*i+ 0+j] = dl * grid1[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+0)) ? -1 : 1); + y4[k][QK_K*ibl+32*ib+4*i+16+j] = dl * grid2[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+4)) ? -1 : 1); + } + } + } + } + } +} + +void vec_dot_iq3_s_r4_q8_k(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if GGML_USE_IQK_MULMAT + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ3_S_R4, vx, 0, GGML_TYPE_Q8_K, vy, 0, s, 0, 0, 1)) { + return; + } +#endif + GGML_ASSERT(n%QK4_NL == 0); + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); + GGML_UNUSED(bx); + GGML_UNUSED(by); +} + +//================================================ + +void iqk_repack_tensor(struct ggml_tensor * tensor) { /* In-place run-time repack: if tensor->type has an entry in k_map, every group of num_rows rows is copied to a scratch buffer and rewritten over itself by the entry's repack function, then tensor->type is switched to new_type. Nothing is done for a null or non-contiguous tensor, one named "token_embd.weight", one with ne[2]*ne[3] > 1, or one whose ne[1] is not a multiple of 4 and of num_rows. */ + constexpr int kChunk = 8; /* row groups handed to a thread per work item */ + if (!tensor) return; + if (!ggml_is_contiguous(tensor)) return; + if (strncmp(tensor->name, "token_embd.weight", GGML_MAX_NAME) == 0) return; + if (tensor->ne[1] % 4 || tensor->ne[2]*tensor->ne[3] > 1) return; + static const std::unordered_map<ggml_type, Repack> k_map = { /* source type -> { interleaved type, rows per group, repack function } */ + { GGML_TYPE_IQ2_K, { GGML_TYPE_IQ2_K_R4, 4, (Repack::repack_func)repack_iq2_k} }, + { GGML_TYPE_IQ3_K, { GGML_TYPE_IQ3_K_R4, 4, (Repack::repack_func)repack_iq3_k} }, + { GGML_TYPE_IQ4_K, { GGML_TYPE_IQ4_K_R4, 4, (Repack::repack_func)repack_iq4_k} }, + { GGML_TYPE_IQ5_K, { GGML_TYPE_IQ5_K_R4, 4, (Repack::repack_func)repack_iq5_k} }, + { GGML_TYPE_IQ4_XS, { GGML_TYPE_IQ4_XS_R4, 4, (Repack::repack_func)repack_iq4_xs} }, + { GGML_TYPE_IQ4_KS, { GGML_TYPE_IQ4_KS_R4, 4, (Repack::repack_func)repack_iq4_ks} }, + { GGML_TYPE_IQ4_NL, { GGML_TYPE_IQ4_NL_R4, 4, (Repack::repack_func)repack_iq4_nl} }, + { GGML_TYPE_IQ2_BN, { GGML_TYPE_IQ2_BN_R4, 4, (Repack::repack_func)repack_iq2_bn} }, + { GGML_TYPE_IQ2_XXS,{ GGML_TYPE_IQ2_XXS_R4,4, (Repack::repack_func)repack_iq2_xxs} }, + { GGML_TYPE_IQ2_XS, { GGML_TYPE_IQ2_XS_R4, 4, (Repack::repack_func)repack_iq2_xs} }, + { GGML_TYPE_IQ2_S, { GGML_TYPE_IQ2_S_R4, 4, (Repack::repack_func)repack_iq2_s} }, + { GGML_TYPE_IQ3_XXS,{ GGML_TYPE_IQ3_XXS_R4,4, 
(Repack::repack_func)repack_iq3_xxs} }, + { GGML_TYPE_IQ3_S, { GGML_TYPE_IQ3_S_R4, 4, (Repack::repack_func)repack_iq3_s} }, + { GGML_TYPE_Q2_K, { GGML_TYPE_Q2_K_R4, 4, (Repack::repack_func)repack_q2_k} }, + { GGML_TYPE_Q3_K, { GGML_TYPE_Q3_K_R4, 4, (Repack::repack_func)repack_q3_k} }, + { GGML_TYPE_Q4_K, { GGML_TYPE_Q4_K_R4, 4, (Repack::repack_func)repack_q4_k} }, + { GGML_TYPE_Q5_K, { GGML_TYPE_Q5_K_R4, 4, (Repack::repack_func)repack_q5_k} }, + { GGML_TYPE_Q6_K, { GGML_TYPE_Q6_K_R4, 4, (Repack::repack_func)repack_q6_k} }, + { GGML_TYPE_Q4_0, { GGML_TYPE_Q4_0_R4, 4, (Repack::repack_func)repack_q4_0} }, + { GGML_TYPE_Q5_0, { GGML_TYPE_Q5_0_R4, 4, (Repack::repack_func)repack_q5_0} }, + { GGML_TYPE_Q6_0, { GGML_TYPE_Q6_0_R4, 4, (Repack::repack_func)repack_q6_0} }, + { GGML_TYPE_Q8_0, { GGML_TYPE_Q8_0_R4, 4, (Repack::repack_func)repack_q8_0} }, + { GGML_TYPE_Q8_K, { GGML_TYPE_Q8_K_R8, 8, (Repack::repack_func)repack_q8_k} }, +#ifdef __AVX512BF16__ /* NOTE(review): BF16 and F16 both name plain repack_bf16 here; explicit template arguments may have been lost in extraction - confirm against the definition */ + { GGML_TYPE_BF16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16}}, + { GGML_TYPE_F16, { GGML_TYPE_BF16_R16, 16, (Repack::repack_func)repack_bf16} }, +#endif + }; + + auto it = k_map.find(tensor->type); + if (it == k_map.end()) return; + if (tensor->ne[1] % it->second.num_rows) return; + + auto& r = it->second; + + int max_thread = std::max(1, int(std::thread::hardware_concurrency()/2)); + int num_chunks = (tensor->ne[1] + kChunk*r.num_rows - 1)/(kChunk*r.num_rows); + int nthread = std::min(num_chunks, max_thread); + + //printf("%s(%s): %s -> %s. 
%d rows, %d chunks, %d threads\n", __func__, tensor->name, ggml_type_name(tensor->type), ggml_type_name(r.new_type), + // int(tensor->ne[1]), num_chunks, nthread); + + std::atomic<int> counter(0); /* shared work index: each fetch_add hands one chunk to whichever thread asks next */ + auto compute = [&counter, &r, tensor, num_chunks, chunkSize = kChunk] () { + int nrows = tensor->ne[1]; + int n_per_row = tensor->ne[0]; + auto row_size = ggml_row_size(tensor->type, n_per_row); + std::vector<char> qtmp(r.num_rows*row_size); /* per-thread scratch holding one row group */ + auto data = (char *)tensor->data; + while (true) { + int chunk = counter.fetch_add(1); + if (chunk >= num_chunks) break; + int first_row = chunk*chunkSize*r.num_rows; + int last_row = std::min(first_row + chunkSize*r.num_rows, nrows); + for (int row = first_row; row < last_row; row += r.num_rows) { + std::memcpy(qtmp.data(), data + row*row_size, r.num_rows*row_size); /* save the source rows first: repack writes over the same bytes */ + r.repack(r.num_rows, n_per_row, qtmp.data(), data + row*row_size); + } + } + }; + std::vector<std::thread> workers(std::max(0, nthread-1)); /* compute() also runs on the calling thread, so only nthread-1 extra threads; clamped so a zero-row tensor (num_chunks == 0) does not request -1 workers */ + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + + tensor->type = r.new_type; +} + +void dequantize_row_ms_i2s(const void * vx, float * y, int64_t k) { + constexpr int kBlockSize = 128; + constexpr int kGroupSize = kBlockSize/4; + GGML_ASSERT(k % kBlockSize == 0); + const uint8_t * x = (const uint8_t *)vx; + const float * dptr = (const float *)(x + k/4); + const float d = dptr[0]; + int nb = k/kBlockSize; + for (int ib = 0; ib < nb; ++ib) { + for (int ig = 0; ig < kBlockSize/kGroupSize; ++ig) { + int shift = 6 - 2*ig; + for (int j = 0; j < kGroupSize; ++j) { + y[j] = d * (((x[j] >> shift) & 3) - 1); + } + y += kGroupSize; + } + x += kGroupSize; + } +} namespace { #ifdef __AVX2__ -static inline float hsum_float_4(__m128 x) { - x = _mm_add_ps(x, _mm_movehl_ps(x, x)); - x = _mm_add_ss(x, _mm_movehdup_ps(x)); - return _mm_cvtss_f32(x); -} -static inline float hsum_float_8(__m256 x) { - return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); -} __m128 hsum_float_4x4(__m128 * accm) { 
accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2])); accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3])); @@ -3781,6 +6799,8 @@ std::vector QuantizerIQKT::clus return result; } +// ========================================== iq2_kt ==================================================== + using QuantizerIQ2KT = QuantizerIQKT<32, 8, 16>; const QuantizerIQ2KT& iq2kt_quantizer() { diff --git a/ggml/src/iqk/iqk_quantize.h b/ggml/src/iqk/iqk_quantize.h index 364165e2..46e4d934 100644 --- a/ggml/src/iqk/iqk_quantize.h +++ b/ggml/src/iqk/iqk_quantize.h @@ -79,7 +79,166 @@ size_t quantize_iq4_kt(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst void dequantize_row_iq4_kt(const block_iq4_kt * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void vec_dot_iq4_kt_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void quantize_row_iq4_nl_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_nl_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_nl_r4(const block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_nl_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q4_0_r4_ref(const float * GGML_RESTRICT x, block_iq4_nl_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q4_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q4_0_r4(const 
block_iq4_nl_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q4_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q8_0_r4_ref(const float * GGML_RESTRICT x, block_q8_0_x4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_0_r4(const block_q8_0_x4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q5_0_r4_ref(const float * GGML_RESTRICT x, block_q5_0_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q5_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q5_0_r4(const block_q5_0_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q5_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q6_0_r4_ref(const float * GGML_RESTRICT x, block_q6_0_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_0_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q6_0_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q6_0_r4(const block_q6_0_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q6_0_r4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq4_xs_r4_ref(const float * GGML_RESTRICT x, block_iq4_xs_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_xs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_xs_r4(const block_iq4_xs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_xs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_bn_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_bn (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_bn (const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_bn (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void vec_dot_iq2_bn_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_bn_r4_ref (const float * GGML_RESTRICT x, block_iq2_bn * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_bn_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_bn_r4(const block_iq2_bn * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_bn_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void vec_dot_iq2_bn_r4_q8_K64(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q3_k_r4_ref(const float * 
GGML_RESTRICT x, block_q3_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q3_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q3_k_r4(const block_q3_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q3_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q2_k_r4_ref(const float * GGML_RESTRICT x, block_q2_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q2_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q2_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q2_k_r4(const block_q2_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q2_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q4_k_r4_ref(const float * GGML_RESTRICT x, block_q4_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q4_k_r4(const block_q4_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q4_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q5_k_r4_ref(const float * GGML_RESTRICT x, block_q5_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); 
+size_t quantize_q5_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q5_k_r4(const block_q5_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q5_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q6_k_r4_ref(const float * GGML_RESTRICT x, block_q6_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q6_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q6_k_r4(const block_q6_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q6_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq5_k_r4_ref(const float * GGML_RESTRICT x, block_iq5_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq5_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq5_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq5_k_r4(const block_iq5_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq5_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq4_k_r4_ref(const float * GGML_RESTRICT x, block_iq4_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); 
+void dequantize_row_iq4_k_r4(const block_iq4_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq4_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq3_k_r4_ref(const float * GGML_RESTRICT x, block_iq3_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq3_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq3_k_r4(const block_iq3_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq3_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_k_r4_ref(const float * GGML_RESTRICT x, block_iq2_k_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_k_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_k_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_k_r4(const block_iq2_k_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_k_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq4_ks_r4_ref(const float * GGML_RESTRICT x, block_iq4_ks_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_ks_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq4_ks_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq4_ks_r4(const block_iq4_ks_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void 
vec_dot_iq4_ks_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_xxs_r4_ref(const float * GGML_RESTRICT x, block_iq2_xxs_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_xxs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_xxs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_xxs_r4(const block_iq2_xxs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_xxs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_xs_r4_ref(const float * GGML_RESTRICT x, block_iq2_xs_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_xs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_xs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_xs_r4(const block_iq2_xs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_xs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq2_s_r4_ref(const float * GGML_RESTRICT x, block_iq2_s_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq2_s_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq2_s_r4(const block_iq2_s_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq2_s_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const 
void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq3_xxs_r4_ref(const float * GGML_RESTRICT x, block_iq3_xxs_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_xxs_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq3_xxs_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq3_xxs_r4(const block_iq3_xxs_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq3_xxs_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_iq3_s_r4_ref(const float * GGML_RESTRICT x, block_iq3_s_r4 * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_r4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_iq3_s_r4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_iq3_s_r4(const block_iq3_s_r4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_iq3_s_r4_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void quantize_row_q8_k_r8_ref(const float * GGML_RESTRICT x, block_q8_k_r8 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_k_r8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +size_t quantize_q8_k_r8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +void dequantize_row_q8_k_r8(const block_q8_k_r8 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void vec_dot_q8_k_r8_q8_k(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void iqk_quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT 
vy, int64_t k); +void quantize_row_q8_K64_ref(const float * GGML_RESTRICT x, block_q8_K64 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K32(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_KR8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void repack_f32_bf16_r16 (const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); +void repack_bf16_bf16_r16(const void * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row); + +void iqk_repack_tensor(struct ggml_tensor * tensor); + +// So we can re-pack Microsoft's BitNet I2_S quants +void dequantize_row_ms_i2s(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); #ifdef __cplusplus } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1bea66aa..90d5efec 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -89,6 +89,8 @@ class Keys: EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_GATING_FUNC = "{arch}.expert_gating_func" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -257,6 +259,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() + FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() @@ -387,6 +390,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = 
{ MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", @@ -978,6 +982,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B ], MODEL_ARCH.CHATGLM : [ MODEL_TENSOR.TOKEN_EMBD, @@ -1177,6 +1182,10 @@ class GGMLQuantizationType(IntEnum): IQ2_TN = 42, +class ExpertGatingFuncType(IntEnum): + SOFTMAX = 1 + SIGMOID = 2 + # TODO: add GGMLFileType from ggml_ftype in ggml.h diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 76385a82..e31bf97b 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,7 @@ from .constants import ( RopeScalingType, PoolingType, TokenType, + ExpertGatingFuncType, ) from .quants import quant_shape_from_byte_shape @@ -670,6 +671,12 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_expert_weights_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value) + + def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: + self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 9aa2209e..a70b69c5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -251,6 +251,10 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert_gate", 
# qwen2moe ), + MODEL_TENSOR.FFN_EXP_PROBS_B: ( + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox diff --git a/include/llama.h b/include/llama.h index 4571c4ff..a7b2dff7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -93,6 +93,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 23, //llama.cpp lists this as 28 }; // note: these values should be synchronized with ggml_rope @@ -182,6 +183,31 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ2_KT = 149, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ3_KT = 150, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ4_KT = 151, // except 1d tensors + // + LLAMA_FTYPE_MOSTLY_Q4_0_R4 = 202, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_0_R4 = 207, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0_R4 = 208, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K_R4 = 210, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_R4 = 211, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_R4 = 214, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_R4 = 216, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K_R4 = 218, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 = 219, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 = 220, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 = 223, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 = 225, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_S_R4 = 226, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_M_R4 = 229, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 = 230, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_0_R4 = 335, // except 1d tensors + LLAMA_FTYPE_MOSTLY_BF16_R16 = 232, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_BN_R4 = 337, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ2_K_R4 = 338, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ3_K_R4 = 339, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_K_R4 
= 340, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ5_K_R4 = 341, // except 1d tensors + LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 = 345, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_K_R8 = 399, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -310,6 +336,7 @@ extern "C" { bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data + bool repack_tensors;// repack if available }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -1022,6 +1049,8 @@ extern "C" { bool add_ass, char * buf, int32_t length); + // Get list of built-in chat templates + LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len); // // Grammar diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46a6ad56..3aea013a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(llama ) target_include_directories(llama PUBLIC . 
../include) +target_include_directories(llama PRIVATE ../ggml/src) target_compile_features (llama PUBLIC cxx_std_11) # don't bump target_link_libraries(llama PUBLIC ggml) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 749f8571..4bd5aa81 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -367,6 +367,13 @@ struct llm_tokenizer_bpe { "\\p{N}+", }; break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + regex_exprs = { + "\\p{N}{1,3}", + "[一-龥぀-ゟ゠-ヿ]+", + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", diff --git a/src/llama.cpp b/src/llama.cpp index 1b2a0dfe..0a73bb0e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9,6 +9,9 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +// TODO: fix this include +#include "iqk/iqk_quantize.h" + #ifdef GGML_USE_RPC # include "ggml-rpc.h" #endif @@ -103,7 +106,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 +#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 // // helpers @@ -291,6 +294,8 @@ enum llm_kv { LLM_KV_EXPERT_USED_COUNT, LLM_KV_EXPERT_SHARED_COUNT, LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -396,6 +401,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, + { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -517,6 +524,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, @@ -1208,6 +1216,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, }, }, { @@ -1355,6 +1364,76 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }; +enum llm_chat_template { + LLM_CHAT_TEMPLATE_CHATML, + LLM_CHAT_TEMPLATE_LLAMA_2, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP, + LLM_CHAT_TEMPLATE_MISTRAL_V1, + LLM_CHAT_TEMPLATE_MISTRAL_V3, + LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, + LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_FALCON_3, + LLM_CHAT_TEMPLATE_ZEPHYR, + LLM_CHAT_TEMPLATE_MONARCH, + LLM_CHAT_TEMPLATE_GEMMA, + LLM_CHAT_TEMPLATE_ORION, + LLM_CHAT_TEMPLATE_OPENCHAT, + LLM_CHAT_TEMPLATE_VICUNA, + LLM_CHAT_TEMPLATE_VICUNA_ORCA, + LLM_CHAT_TEMPLATE_DEEPSEEK, + LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_COMMAND_R, + LLM_CHAT_TEMPLATE_LLAMA_3, + LLM_CHAT_TEMPLATE_CHATGML_3, + LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_MINICPM, + LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_RWKV_WORLD, + LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GIGACHAT, + LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_UNKNOWN, +}; + +static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { + { "chatml", LLM_CHAT_TEMPLATE_CHATML }, + { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 }, + { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS }, + { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS }, + { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP }, + { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 }, + { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, + { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, + {
"mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, + { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, + { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, + { "gemma", LLM_CHAT_TEMPLATE_GEMMA }, + { "orion", LLM_CHAT_TEMPLATE_ORION }, + { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT }, + { "vicuna", LLM_CHAT_TEMPLATE_VICUNA }, + { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, + { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, + { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, + { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, + { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, + { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, +}; + + static llm_arch llm_arch_from_string(const std::string & name) { for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { @@ -2183,6 +2262,7 @@ enum e_model { MODEL_70B, MODEL_236B, MODEL_314B, + MODEL_671B, MODEL_SMALL, MODEL_MEDIUM, MODEL_LARGE, @@ -2200,6 +2280,21 @@ static const size_t kiB = 1024; static const size_t MiB = 1024*kiB; static const size_t GiB = 1024*MiB; +enum llm_expert_gating_func_type { + LLM_EXPERT_GATING_FUNC_SOFTMAX = 1, + LLM_EXPERT_GATING_FUNC_SIGMOID = 2, +}; + +static const char * llama_expert_gating_func_name(llm_expert_gating_func_type type) { + switch (type) { + case LLM_EXPERT_GATING_FUNC_SOFTMAX: return "softmax"; + case LLM_EXPERT_GATING_FUNC_SIGMOID: return "sigmoid"; + default: return "unknown"; + } +} + + + struct llama_hparams { bool vocab_only; bool rope_finetuned; @@ -2229,6 +2324,8 @@ struct llama_hparams { uint32_t n_ff_shexp = 0; uint32_t n_expert_shared = 0; float 
expert_weights_scale = 0.0; + bool expert_weights_norm = false; + uint32_t expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; float f_norm_eps; float f_norm_rms_eps; @@ -2499,6 +2596,7 @@ struct llama_layer { struct ggml_tensor * ffn_down_b = nullptr; // b2 struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act; + struct ggml_tensor * ffn_exp_probs_b = nullptr; // mamba proj struct ggml_tensor * ssm_in; @@ -3653,6 +3751,7 @@ struct llama_model_loader { bool use_mmap = false; bool check_tensors; + bool repack_tensors = false; llama_files files; llama_ftype ftype; @@ -3686,7 +3785,7 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -3828,6 +3927,7 @@ struct llama_model_loader { case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; + case GGML_TYPE_BF16_R16:ftype = LLAMA_FTYPE_MOSTLY_BF16_R16;break; case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; @@ -3836,31 +3936,53 @@ struct llama_model_loader { case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q3_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_R4; break; case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q4_K_R4: ftype = 
LLAMA_FTYPE_MOSTLY_Q4_K_R4; break; case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q5_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_R4; break; case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_Q6_K_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_K_R4; break; + case GGML_TYPE_Q8_K_R8: ftype = LLAMA_FTYPE_MOSTLY_Q8_K_R8; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XXS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS_R4; break; case GGML_TYPE_IQ2_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KS; break; - case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break; - case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_M; break; + case GGML_TYPE_IQ2_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_M_R4;break; case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ3_XXS_R4: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4; break; + case GGML_TYPE_IQ2_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ2_KT; break; case GGML_TYPE_IQ3_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ3_KT; break; case GGML_TYPE_IQ4_KT: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KT; break; case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; case GGML_TYPE_IQ1_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ1_BN; break; case GGML_TYPE_IQ2_BN: ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN; break; + case GGML_TYPE_IQ2_BN_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_BN_R4;break; case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_NL_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL_R4;break; + case GGML_TYPE_IQ4_XS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS_R4;break; + case GGML_TYPE_Q4_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_R4; break; + case GGML_TYPE_Q5_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q5_0_R4; break; + case 
GGML_TYPE_Q6_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q6_0_R4; break; + case GGML_TYPE_Q8_0_R4: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_R4; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ4_KS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS; break; + case GGML_TYPE_IQ4_KS_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_KS_R4; break; case GGML_TYPE_IQ4_KSS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_KSS; break; case GGML_TYPE_IQ2_K: ftype = LLAMA_FTYPE_MOSTLY_IQ2_K; break; + case GGML_TYPE_IQ2_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ2_K_R4;break; case GGML_TYPE_IQ3_K: ftype = LLAMA_FTYPE_MOSTLY_IQ3_K; break; + case GGML_TYPE_IQ3_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_K_R4;break; case GGML_TYPE_IQ4_K: ftype = LLAMA_FTYPE_MOSTLY_IQ4_K; break; + case GGML_TYPE_IQ4_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ4_K_R4;break; case GGML_TYPE_IQ5_K: ftype = LLAMA_FTYPE_MOSTLY_IQ5_K; break; + case GGML_TYPE_IQ5_K_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ5_K_R4;break; case GGML_TYPE_IQ6_K: ftype = LLAMA_FTYPE_MOSTLY_IQ6_K; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_IQ3_S_R4:ftype = LLAMA_FTYPE_MOSTLY_IQ3_S_R4;break; case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; @@ -3915,9 +4037,13 @@ struct llama_model_loader { LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); use_mmap = false; } + if (repack_tensors) { + use_mmap = false; + } this->use_mmap = use_mmap; this->check_tensors = check_tensors; + this->repack_tensors = repack_tensors; } ~llama_model_loader() { @@ -4530,6 +4656,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_BF16_R16: return "BF16_R16"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case 
LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -4537,40 +4664,63 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q6_0: return "Q6_0"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_R4: return "Q2_K_R4"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q3_K_R4: return "Q3_K_R4"; case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_R4: return "Q4_K_R4"; case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_R4: return "Q5_K_R4"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_Q6_K_R4: return "Q6_K_R4"; + case LLAMA_FTYPE_MOSTLY_Q8_K_R8: return "Q8_K_R8"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:return "IQ2_XXS_R4 - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS_R4:return "IQ2_XS_R4 - 2.3125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_KS: return "IQ2_KS - 2.1875 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_KT: return "IQ2_KT - 2.125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M_R4: return "IQ2_M_R4 - 2.7 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_KT: return "IQ2_KT - 2.125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KT: return "IQ3_KT - 3.125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KT: return "IQ4_KT - 4.0 bpw"; + 
case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: return "IQ3_XXS_R4 - 3.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL_R4:return "IQ4_NL_R4 - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:return "IQ4_XS_R4 - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_Q4_0_R4: return "Q4_0_R4 - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_Q5_0_R4: return "Q5_0_R4 - 5.5 bpw"; + case LLAMA_FTYPE_MOSTLY_Q6_0_R4: return "Q6_0_R4 - 6.5 bpw"; + case LLAMA_FTYPE_MOSTLY_Q8_0_R4: return "Q8_0_R4 - 8.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KS: return "IQ4_KS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:return "IQ4_KS_R4 - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: return "IQ4_KSS - 4.0 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_K: return "IQ2_K - 2.375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_K_R4: return "IQ2_K_R4 - 2.375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_K: return "IQ3_K - 3.4325 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_K_R4: return "IQ3_K_R4 - 3.4325 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_KL: return "IQ3_KL - 4 bpw"; case LLAMA_FTYPE_MOSTLY_IQ4_K: return "IQ4_K - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_K_R4: return "IQ4_K_R4 - 4.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ5_K: return "IQ5_K - 5.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ5_K_R4: return "IQ5_K_R4 - 5.5 bpw"; case LLAMA_FTYPE_MOSTLY_IQ6_K: return "IQ6_K - 6.6 bpw"; case LLAMA_FTYPE_MOSTLY_IQ1_BN: return "IQ1_BN - 1.625 bpw Bitnet"; case LLAMA_FTYPE_MOSTLY_IQ2_BN: return "IQ2_BN - 2.00 bpw Bitnet"; + case LLAMA_FTYPE_MOSTLY_IQ2_BN_R4:return "IQ2_BN_R4 - 2.00 bpw Bitnet"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S_R4: return "IQ3_S_R4 - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8"; @@ 
-4628,6 +4778,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_70B: return "70B"; case MODEL_236B: return "236B"; case MODEL_314B: return "314B"; + case MODEL_671B: return "671B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; @@ -5253,11 +5404,19 @@ static void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == 0) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + hparams.expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; + } ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); switch (hparams.n_layer) { case 27: model.type = e_model::MODEL_16B; break; case 60: model.type = e_model::MODEL_236B; break; + case 61: model.type = e_model::MODEL_671B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -5503,7 +5662,8 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || - tokenizer_pre == "llama-bpe") { + tokenizer_pre == "llama-bpe"|| + tokenizer_pre == "falcon3") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; vocab.tokenizer_ignore_merges = true; vocab.tokenizer_add_bos = true; @@ -5515,6 +5675,10 @@ static void llm_load_vocab( tokenizer_pre == "deepseek-coder") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "falcon") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; 
@@ -5541,7 +5705,7 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; vocab.tokenizer_clean_spaces = false; } else if ( - tokenizer_pre == "qwen2") { + tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; vocab.tokenizer_clean_spaces = false; } else if ( @@ -6025,6 +6189,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((enum llm_expert_gating_func_type) hparams.expert_gating_func)); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); } @@ -7490,6 +7656,7 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_exp_probs_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); /* optional: the expert-selection bias was introduced with DeepSeek V3, so older GGUFs lack it and the pointer stays nullptr */ GGML_ASSERT(n_expert > 0); GGML_ASSERT(n_expert_used > 0); @@ -7853,6 +8020,19 @@ static bool llm_load_tensors( } } + if (!ml.use_mmap && ml.repack_tensors) { + int n_repacked = 0; + for (auto& it : model.tensors_by_name) { + if (ggml_backend_buffer_is_host(it.second->buffer)) { + auto orig_type = it.second->type; + iqk_repack_tensor(it.second); + if (it.second->type != orig_type) ++n_repacked; + //printf("Repacking tensor %s\n", it.first.c_str()); + } + } + if (n_repacked > 0) printf("============ Repacked %d tensors\n", n_repacked); + } + if (model.arch == LLM_ARCH_BITNET) { auto set_scale
= [] (ggml_tensor * w, ggml_tensor * s) { if (!s) { @@ -7888,7 +8068,7 @@ static bool llm_load_tensors( // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { - llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides); + llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.repack_tensors, params.kv_overrides); model.hparams.vocab_only = params.vocab_only; @@ -8283,12 +8463,14 @@ static struct ggml_tensor * llm_build_moe_ffn( struct ggml_tensor * up_exps, struct ggml_tensor * gate_exps, struct ggml_tensor * down_exps, + struct ggml_tensor * exp_probs_b, int64_t n_expert, int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, bool scale_w, float w_scale, +llm_expert_gating_func_type gating_op, const llm_build_cb & cb, int il) { int64_t n_embd = cur->ne[0]; @@ -8297,11 +8479,32 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); - ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + //ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + ggml_tensor * probs = nullptr; + switch (gating_op) { + case LLM_EXPERT_GATING_FUNC_SOFTMAX: + { + probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + } break; + case LLM_EXPERT_GATING_FUNC_SIGMOID: + { + probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens] + } break; + default: + GGML_ABORT("fatal error"); + } cb(probs, "ffn_moe_probs", il); + // add experts selection bias - introduced in DeepSeek V3 + // leave probs unbiased as it's later used to get expert weights + ggml_tensor * selection_probs = probs; + if (exp_probs_b != nullptr) { + selection_probs = ggml_add(ctx, probs, exp_probs_b); + cb(selection_probs, "ffn_moe_probs_biased", il); + } + 
// select experts - ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); @@ -9117,9 +9320,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); } @@ -9610,9 +9815,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_GELU, true, false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -9751,9 +9958,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -10881,9 +11090,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -13046,9 +13257,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + LLM_EXPERT_GATING_FUNC_SOFTMAX, cb, il); cb(cur, "ffn_moe_out", il); @@ -13261,9 +13474,11 @@ struct llm_build_context { model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, - LLM_FFN_SILU, false, + LLM_FFN_SILU, 
hparams.expert_weights_norm, true, hparams.expert_weights_scale, + (enum llm_expert_gating_func_type) hparams.expert_gating_func, cb, il); cb(moe_out, "ffn_moe_out", il); @@ -15635,6 +15850,12 @@ static void llama_tensor_dequantize_internal( throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); } + if (tensor->type == GGML_TYPE_I2_S) { + // we need to dequantize the entire tensor for I2_S + qtype.to_float(tensor->data, f32_output, nelements); + return; + } + if (nthread < 2) { if (tensor->type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); @@ -15740,14 +15961,21 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { new_type = !qs.has_output ? GGML_TYPE_IQ4_K : GGML_TYPE_Q5_K; } - else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_output) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { + new_type = !qs.has_output ? 
GGML_TYPE_IQ4_K_R4 : GGML_TYPE_Q5_K_R4; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4) && !qs.has_output) { new_type = GGML_TYPE_IQ5_K; } - else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_IQ6_K) { + else if (new_type != GGML_TYPE_Q8_0 && new_type != GGML_TYPE_Q8_0_R4 && new_type != GGML_TYPE_IQ6_K && new_type != GGML_TYPE_Q6_K_R4 && + new_type != GGML_TYPE_Q8_K_R8) { new_type = GGML_TYPE_Q6_K; } } @@ -15756,44 +15984,104 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = qs.params->token_embedding_type; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS_R4) { new_type = GGML_TYPE_Q2_K; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4) { new_type = GGML_TYPE_IQ3_S; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { new_type = GGML_TYPE_IQ3_S; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { + new_type = GGML_TYPE_IQ3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN || ftype == LLAMA_FTYPE_MOSTLY_IQ2_BN_R4) { new_type = GGML_TYPE_IQ4_NL; } else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) { new_type = GGML_TYPE_Q4_0; } + else if 
(new_type == GGML_TYPE_IQ4_NL_R4) { + new_type = GGML_TYPE_IQ4_NL; + } + else if (new_type == GGML_TYPE_IQ4_XS_R4) { + new_type = GGML_TYPE_IQ4_XS; + } + else if (new_type == GGML_TYPE_Q2_K_R4) { + new_type = GGML_TYPE_Q2_K; + } + else if (new_type == GGML_TYPE_Q3_K_R4) { + new_type = GGML_TYPE_Q3_K; + } + else if (new_type == GGML_TYPE_Q4_K_R4) { + new_type = GGML_TYPE_Q4_K; + } + else if (new_type == GGML_TYPE_Q5_K_R4) { + new_type = GGML_TYPE_Q5_K; + } + else if (new_type == GGML_TYPE_Q6_K_R4) { + new_type = GGML_TYPE_Q6_K; + } + else if (new_type == GGML_TYPE_Q8_K_R8) { + new_type = GGML_TYPE_Q8_0; + } + else if (new_type == GGML_TYPE_IQ2_K_R4) { + new_type = GGML_TYPE_IQ2_K; + } + else if (new_type == GGML_TYPE_IQ3_K_R4) { + new_type = GGML_TYPE_IQ3_K; + } + else if (new_type == GGML_TYPE_IQ3_S_R4) { + new_type = GGML_TYPE_IQ3_S; + } + else if (new_type == GGML_TYPE_IQ4_K_R4) { + new_type = GGML_TYPE_IQ4_K; + } + else if (new_type == GGML_TYPE_IQ5_K_R4) { + new_type = GGML_TYPE_IQ5_K; + } + else if (new_type == GGML_TYPE_IQ4_KS_R4) { + new_type = GGML_TYPE_IQ4_KS; + } + else if (new_type == GGML_TYPE_Q4_0_R4) { + new_type = GGML_TYPE_Q4_0; + } + else if (new_type == GGML_TYPE_Q5_0_R4) { + new_type = GGML_TYPE_Q5_0; + } + else if (new_type == GGML_TYPE_Q6_0_R4) { + new_type = GGML_TYPE_Q6_0; + } + else if (new_type == GGML_TYPE_Q8_0_R4) { + new_type = GGML_TYPE_Q8_0; + } + else if (new_type == GGML_TYPE_BF16_R16) { + new_type = GGML_TYPE_BF16; + } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS|| ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4 || ftype == 
LLAMA_FTYPE_MOSTLY_IQ2_XS_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4) { + bool is_iq2_m = ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M_R4; if (name.find("attn_v.weight") != std::string::npos) { if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_IQ4_K; else if (qs.model.hparams.n_gqa() >= 2 || qs.model.hparams.n_expert >= 2) new_type = GGML_TYPE_IQ3_K; - //else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_KT) new_type = GGML_TYPE_IQ3_KT; - else { - new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; - } + else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { new_type = GGML_TYPE_Q4_K; } else if (name.find("attn_qkv.weight") != std::string::npos) { - new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_XXS : GGML_TYPE_IQ2_K; + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? GGML_TYPE_IQ3_XXS : GGML_TYPE_IQ2_K; } else if (name.find("ffn_down") != std::string::npos) { // && ftype != LLAMA_FTYPE_MOSTLY_IQ2_KT) { if (qs.i_ffn_down < qs.n_ffn_down/8) { - new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m ? 
GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; } ++qs.i_ffn_down; } @@ -15802,7 +16090,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type = GGML_TYPE_Q5_K; } else { if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || is_iq2_m) new_type = GGML_TYPE_IQ3_S; } } } else if (name.find("attn_v.weight") != std::string::npos) { @@ -15813,9 +16101,15 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K) { new_type = qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ4_K : GGML_TYPE_IQ3_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4) { + new_type = qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ4_K_R4 : GGML_TYPE_IQ3_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) { new_type = GGML_TYPE_Q4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4 && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; @@ -15830,12 +16124,22 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n // : !qs.has_imatrix ? GGML_TYPE_IQ4_KS : GGML_TYPE_IQ4_KT; new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ5_K : GGML_TYPE_IQ4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K_R4 : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ3_K_R4 + : !qs.has_imatrix ? 
GGML_TYPE_IQ3_K_R4 : GGML_TYPE_IQ3_XXS_R4; + } else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_S_R4 && qs.model.hparams.n_gqa() >= 2) { + new_type = GGML_TYPE_IQ4_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_K && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 && qs.model.hparams.n_gqa() >= 2) { + new_type = GGML_TYPE_IQ4_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KL) { new_type = qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ5_K : GGML_TYPE_IQ4_K; } @@ -15847,20 +16151,23 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ5_K; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) { - //new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_IQ5_K : qs.model.hparams.n_gqa() >= 2 ? GGML_TYPE_IQ4_K - // : !qs.has_imatrix ? GGML_TYPE_IQ4_KS : GGML_TYPE_IQ4_KT; - new_type = qs.model.hparams.n_gqa() >= 4 ? 
GGML_TYPE_IQ5_K : GGML_TYPE_IQ4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 && qs.model.hparams.n_gqa() >= 2) { + new_type = GGML_TYPE_IQ5_K_R4; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K && qs.model.hparams.n_gqa() >= 2) { new_type = GGML_TYPE_IQ5_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_K_R4 && qs.model.hparams.n_gqa() >= 2) { + new_type = GGML_TYPE_IQ5_K; + } else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) { if (qs.model.hparams.n_vocab >= 127999 && (qs.model.type == MODEL_8B || qs.model.type == MODEL_70B)) new_type = GGML_TYPE_Q6_K; @@ -15878,9 +16185,14 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } else if (qs.model.hparams.n_gqa() >= 4) { if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; + else if (new_type == GGML_TYPE_Q2_K_R4 || new_type == GGML_TYPE_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4; else if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_IQ3_S ) new_type = GGML_TYPE_Q4_K; + else if (new_type == GGML_TYPE_IQ3_S_R4) new_type = GGML_TYPE_Q4_K_R4; + else if (new_type == GGML_TYPE_Q3_K_R4) new_type = GGML_TYPE_Q4_K_R4; else if (new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_IQ4_XS) new_type = GGML_TYPE_Q5_K; else if (new_type == GGML_TYPE_IQ4_NL) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_IQ4_NL_R4) new_type = GGML_TYPE_Q5_K; + else if (new_type == GGML_TYPE_IQ4_XS_R4) new_type = GGML_TYPE_Q5_K; else if (new_type == GGML_TYPE_Q5_K) new_type = GGML_TYPE_Q6_K; } ++qs.i_attention_wv; @@ -15894,7 +16206,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, 
ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { new_type = GGML_TYPE_IQ2_S; } } else if (name.find("attn_q.weight") != std::string::npos) { @@ -15902,7 +16214,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } - else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) { new_type = GGML_TYPE_IQ2_S; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) { @@ -15917,12 +16229,18 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) { if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4) { + if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) { new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT && !qs.has_imatrix) { new_type = i_layer < n_layer/8 ? GGML_TYPE_IQ4_K : GGML_TYPE_IQ3_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 && !qs.has_imatrix) { + new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K_R4 : GGML_TYPE_IQ3_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? 
GGML_TYPE_Q4_K @@ -15946,14 +16264,22 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; } } - else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || - ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS) && !qs.has_imatrix) { + else if (i_layer < n_layer/8 && !qs.has_imatrix && + (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_KSS || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4)) { new_type = GGML_TYPE_Q5_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ4_KS_R4 && i_layer < n_layer/8 && !qs.has_imatrix) { + new_type = GGML_TYPE_Q5_K_R4; + } else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) { new_type = GGML_TYPE_Q5_K; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) { + new_type = GGML_TYPE_Q5_K; + } else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0) && qs.has_imatrix && i_layer < n_layer/8) { // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0. @@ -15961,6 +16287,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix. new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? 
GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_0_R4 && qs.has_imatrix && i_layer < n_layer/8) { + new_type = GGML_TYPE_IQ4_NL_R4; + } ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { if (qs.params->attn_output_type < GGML_TYPE_COUNT) new_type = qs.params->attn_output_type; @@ -15970,17 +16299,23 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ4_K || - ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT) { + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT || + ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS_R4 || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_Q2_K_R4|| ftype == LLAMA_FTYPE_MOSTLY_IQ4_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_K_R4 || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4 || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S_R4) { new_type = GGML_TYPE_Q5_K; } } else { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KT && qs.model.hparams.n_gqa() >= 4) new_type = GGML_TYPE_IQ3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4) new_type = GGML_TYPE_IQ3_K_R4; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = 
GGML_TYPE_IQ4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K ) new_type = GGML_TYPE_IQ3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_K_R4) new_type = GGML_TYPE_IQ3_K_R4; else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_KL ) new_type = GGML_TYPE_IQ4_KS; } } else { @@ -16038,10 +16373,15 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S || new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S || new_type == GGML_TYPE_IQ1_M || new_type == GGML_TYPE_IQ4_K || new_type == GGML_TYPE_IQ2_K || - new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || - new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ2_KT || - new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_IQ3_KT || - new_type == GGML_TYPE_IQ4_KT) { + new_type == GGML_TYPE_IQ2_KT || new_type == GGML_TYPE_IQ3_KT || new_type == GGML_TYPE_IQ4_KT || + new_type == GGML_TYPE_IQ5_K || new_type == GGML_TYPE_IQ3_K || new_type == GGML_TYPE_Q4_K_R4 || + new_type == GGML_TYPE_IQ6_K || new_type == GGML_TYPE_IQ4_KS || new_type == GGML_TYPE_IQ4_XS_R4 || + new_type == GGML_TYPE_IQ2_KS || new_type == GGML_TYPE_IQ4_KSS || new_type == GGML_TYPE_Q6_K_R4 || + new_type == GGML_TYPE_Q5_K_R4 || new_type == GGML_TYPE_Q3_K_R4 || new_type == GGML_TYPE_Q2_K_R4 || + new_type == GGML_TYPE_IQ4_K_R4|| new_type == GGML_TYPE_Q8_K_R8 || new_type == GGML_TYPE_IQ3_K_R4|| + new_type == GGML_TYPE_IQ2_K_R4|| new_type == GGML_TYPE_IQ5_K_R4|| new_type == GGML_TYPE_IQ4_KS_R4 || + new_type == GGML_TYPE_IQ3_XXS_R4 || new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS_R4 || + new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ3_S_R4) { int nx = tensor->ne[0]; int ny = tensor->ne[1]; if (nx % QK_K != 0) { @@ -16051,7 +16391,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n 
++qs.n_k_quantized; } } - if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN) { + if (new_type == GGML_TYPE_IQ1_BN || new_type == GGML_TYPE_IQ2_BN || new_type == GGML_TYPE_IQ2_BN_R4) { int nx = tensor->ne[0]; if (nx % QK_IQ1BN != 0) { convert_incompatible_tensor = true; @@ -16060,28 +16400,45 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n if (convert_incompatible_tensor) { switch (new_type) { case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XXS_R4: case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XS_R4: case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_KT: case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_S_R4: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_KT: case GGML_TYPE_IQ4_KT: + case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_S_R4: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q2_K_R4: case GGML_TYPE_Q3_K: + case GGML_TYPE_Q3_K_R4: case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_K_R4: case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ3_K_R4: case GGML_TYPE_IQ4_KSS: case GGML_TYPE_IQ4_KS: + case GGML_TYPE_IQ4_KS_R4: + case GGML_TYPE_IQ4_XS_R4: case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ4_K_R4: + case GGML_TYPE_Q4_K_R4: case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; case GGML_TYPE_IQ5_K: - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_IQ5_K_R4: + case GGML_TYPE_Q5_K_R4: + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q6_0; break; case GGML_TYPE_IQ6_K: + case GGML_TYPE_Q6_K_R4: + case GGML_TYPE_Q8_K_R8: case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); } @@ -16159,44 +16516,68 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: 
default_type = GGML_TYPE_BF16; break; + case LLAMA_FTYPE_MOSTLY_BF16_R16: default_type = GGML_TYPE_BF16_R16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q2_K_R4: default_type = GGML_TYPE_Q2_K_R4; break; case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_R4: default_type = GGML_TYPE_Q3_K_R4; break; case LLAMA_FTYPE_MOSTLY_Q4_K_S: case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_R4: default_type = GGML_TYPE_Q4_K_R4; break; case LLAMA_FTYPE_MOSTLY_Q5_K_S: case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_R4: default_type = GGML_TYPE_Q5_K_R4; break; case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K_R4: default_type = GGML_TYPE_Q6_K_R4; break; + case LLAMA_FTYPE_MOSTLY_Q8_K_R8: default_type = GGML_TYPE_Q8_K_R8; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4:default_type = GGML_TYPE_IQ2_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XS_R4:default_type = GGML_TYPE_IQ2_XS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ2_KS: default_type = GGML_TYPE_IQ2_KS; break; case LLAMA_FTYPE_MOSTLY_IQ2_KT: default_type = GGML_TYPE_IQ2_KT; break; case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; + case LLAMA_FTYPE_MOSTLY_IQ2_M_R4:default_type = GGML_TYPE_IQ2_S_R4;break; case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ3_KT: 
default_type = GGML_TYPE_IQ3_KT; break; case LLAMA_FTYPE_MOSTLY_IQ4_KT: default_type = GGML_TYPE_IQ4_KT; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4: default_type = GGML_TYPE_IQ3_XXS_R4; break; case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; case LLAMA_FTYPE_MOSTLY_IQ1_BN: default_type = GGML_TYPE_IQ1_BN; break; case LLAMA_FTYPE_MOSTLY_IQ2_BN: default_type = GGML_TYPE_IQ2_BN; break; + case LLAMA_FTYPE_MOSTLY_IQ2_BN_R4:default_type = GGML_TYPE_IQ2_BN_R4;break; case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; + case LLAMA_FTYPE_MOSTLY_IQ4_NL_R4:default_type = GGML_TYPE_IQ4_NL_R4;break; + case LLAMA_FTYPE_MOSTLY_IQ4_XS_R4:default_type = GGML_TYPE_IQ4_XS_R4;break; + case LLAMA_FTYPE_MOSTLY_Q4_0_R4: default_type = GGML_TYPE_Q4_0_R4; break; + case LLAMA_FTYPE_MOSTLY_Q5_0_R4: default_type = GGML_TYPE_Q5_0_R4; break; + case LLAMA_FTYPE_MOSTLY_Q6_0_R4: default_type = GGML_TYPE_Q6_0_R4; break; + case LLAMA_FTYPE_MOSTLY_Q8_0_R4: default_type = GGML_TYPE_Q8_0_R4; break; case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ4_KS: default_type = GGML_TYPE_IQ4_KS; break; + case LLAMA_FTYPE_MOSTLY_IQ4_KS_R4:default_type = GGML_TYPE_IQ4_KS_R4;break; case LLAMA_FTYPE_MOSTLY_IQ4_KSS: default_type = GGML_TYPE_IQ4_KSS; break; case LLAMA_FTYPE_MOSTLY_IQ2_K: default_type = GGML_TYPE_IQ2_K; break; + case LLAMA_FTYPE_MOSTLY_IQ2_K_R4:default_type = GGML_TYPE_IQ2_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ3_K: default_type = GGML_TYPE_IQ3_K; break; + case LLAMA_FTYPE_MOSTLY_IQ3_K_R4:default_type = GGML_TYPE_IQ3_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ3_KL: default_type = GGML_TYPE_IQ3_K; break; case LLAMA_FTYPE_MOSTLY_IQ4_K: default_type = GGML_TYPE_IQ4_K; break; + case LLAMA_FTYPE_MOSTLY_IQ4_K_R4:default_type = GGML_TYPE_IQ4_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ5_K: default_type = GGML_TYPE_IQ5_K; break; + case 
LLAMA_FTYPE_MOSTLY_IQ5_K_R4:default_type = GGML_TYPE_IQ5_K_R4;break; case LLAMA_FTYPE_MOSTLY_IQ6_K: default_type = GGML_TYPE_IQ6_K; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_IQ3_S_R4:default_type = GGML_TYPE_IQ3_S_R4;break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break; case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break; @@ -16224,7 +16605,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s auto v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } - llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); + llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, kv_overrides); ml.init_mappings(false); // no prefetching llama_model model; @@ -16440,6 +16821,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (quantize) { new_type = default_type; + if (new_type == GGML_TYPE_BF16_R16 && strcmp(tensor->name, "token_embd.weight") == 0) { + new_type = GGML_TYPE_BF16; + } // get more optimal quantization type based on the tensor shape, layer, etc. 
if (!params->pure && ggml_is_quantized(default_type)) { @@ -16514,8 +16898,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } if (!params->ignore_imatrix_rules && !imatrix && (new_type == GGML_TYPE_IQ2_XXS || + new_type == GGML_TYPE_IQ2_XXS_R4 || new_type == GGML_TYPE_IQ2_XS || + new_type == GGML_TYPE_IQ2_XS_R4 || new_type == GGML_TYPE_IQ2_S || + new_type == GGML_TYPE_IQ2_S_R4|| new_type == GGML_TYPE_IQ1_S || (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0))) { @@ -16544,6 +16931,102 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; } + else if (new_type == GGML_TYPE_IQ4_NL_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_NL; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ4_XS_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_XS; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q4_0_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q5_0_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q6_0_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q8_0_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q8_0; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q2_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q2_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q3_K_R4) { + 
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q3_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q4_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q5_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q6_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_Q8_K_R8) { + if (tensor->ne[1] % 8 != 0) new_type = GGML_TYPE_Q8_0; + else chunk_size_multiplier = 8; + } + else if (new_type == GGML_TYPE_IQ2_BN_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_BN; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ2_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ3_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ4_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ5_K_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ5_K; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ4_KS_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ4_KS; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ2_XXS_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_XXS; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ2_XS_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_XS; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ2_S_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ2_S; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ3_XXS_R4) { + if (tensor->ne[1] % 4 
!= 0) new_type = GGML_TYPE_IQ3_XXS; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_IQ3_S_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_IQ3_S; + else chunk_size_multiplier = 4; + } + else if (new_type == GGML_TYPE_BF16_R16) { + if (tensor->ne[1] % 16 != 0) new_type = GGML_TYPE_BF16; + else chunk_size_multiplier = 16; + } LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); fflush(stdout); @@ -16827,6 +17310,7 @@ struct llama_model_params llama_model_default_params() { /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, + /*.repack_tensors =*/ false, }; #ifdef GGML_USE_METAL @@ -18239,6 +18723,7 @@ struct llama_data_read { read_to(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 0) { + llama_batch_free(batch); LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); return false; } @@ -19183,18 +19668,116 @@ int32_t llama_detokenize( // chat templates // -// Simple version of "llama_apply_chat_template" that only works with strings -// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. 
+static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { + if (auto it = LLM_CHAT_TEMPLATES.find(tmpl); it != LLM_CHAT_TEMPLATES.end()) { + return it->second; + } + auto tmpl_contains = [&tmpl](const char * haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl_contains("<|im_start|>")) { + return LLM_CHAT_TEMPLATE_CHATML; + } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { + if (tmpl_contains("[SYSTEM_PROMPT]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V7; + } else if ( + // catches official 'v1' template + tmpl_contains("' [INST] ' + system_message") + // catches official 'v3' and 'v3-tekken' templates + || tmpl_contains("[AVAILABLE_TOOLS]") + ) { + // Official mistral 'v1', 'v3' and 'v3-tekken' templates + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + if (tmpl_contains(" [INST]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V1; + } else if (tmpl_contains("\"[INST]\"")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN; + } + return LLM_CHAT_TEMPLATE_MISTRAL_V3; + } else { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl_contains("<>"); + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + bool strip_message = tmpl_contains("content.strip()"); + if (strip_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + } else if (add_bos_inside_history) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + } else if (support_system_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS; + } else { + return LLM_CHAT_TEMPLATE_LLAMA_2; + } + } + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { + return LLM_CHAT_TEMPLATE_PHI_3; + } else if (tmpl_contains("<|assistant|>") && 
tmpl_contains("<|user|>")) { + return LLM_CHAT_TEMPLATE_FALCON_3; + } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { + return LLM_CHAT_TEMPLATE_ZEPHYR; + } else if (tmpl_contains("bos_token + message['role']")) { + return LLM_CHAT_TEMPLATE_MONARCH; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GEMMA; + } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + return LLM_CHAT_TEMPLATE_ORION; + } else if (tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106 + return LLM_CHAT_TEMPLATE_OPENCHAT; + } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + if (tmpl_contains("SYSTEM: ")) { + return LLM_CHAT_TEMPLATE_VICUNA_ORCA; + } + return LLM_CHAT_TEMPLATE_VICUNA; + } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) { + // deepseek-ai/deepseek-coder-33b-instruct + return LLM_CHAT_TEMPLATE_DEEPSEEK; + } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) { + // CohereForAI/c4ai-command-r-plus + return LLM_CHAT_TEMPLATE_COMMAND_R; + } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) { + return LLM_CHAT_TEMPLATE_LLAMA_3; + } else if (tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + return LLM_CHAT_TEMPLATE_CHATGML_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGML_4; + } else if (tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + return LLM_CHAT_TEMPLATE_MINICPM; + } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { + // original: if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_3; + } else if 
(tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + return LLM_CHAT_TEMPLATE_EXAONE_3; + } else if (tmpl_contains("rwkv-world")) { + return LLM_CHAT_TEMPLATE_RWKV_WORLD; + } else if (tmpl_contains("<|start_of_role|>")) { + return LLM_CHAT_TEMPLATE_GRANITE; + } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { + return LLM_CHAT_TEMPLATE_GIGACHAT; + } else if (tmpl_contains("<|role_start|>")) { + return LLM_CHAT_TEMPLATE_MEGREZ; + } + return LLM_CHAT_TEMPLATE_UNKNOWN; +} + static int32_t llama_chat_apply_template_internal( - const std::string & tmpl, + const llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; - auto tmpl_contains = [&tmpl](std::string haystack) -> bool { - return tmpl.find(haystack) != std::string::npos; - }; - if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) { + if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template for (auto message : chat) { ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; @@ -19202,16 +19785,59 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { + // Official mistral 'v7' template + // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + for (auto message : chat) { + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; + } else if (role == "user") { + ss << 
"[INST] " << content << "[/INST]"; + } + else { + ss << " " << content << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) { + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : ""; + std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " "; + bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3; + bool is_inside_turn = false; + for (auto message : chat) { + if (!is_inside_turn) { + ss << leading_space << "[INST]" << trailing_space; + is_inside_turn = true; + } + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << content << "\n\n"; + } else if (role == "user") { + ss << content << leading_space << "[/INST]"; + } else { + ss << trailing_space << (trim_assistant_message ? 
trim(content) : content) << ""; + is_inside_turn = false; + } + } + } else if ( + tmpl == LLM_CHAT_TEMPLATE_LLAMA_2 + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) { // llama2 template and its variants // [variant] support system message - bool support_system_message = tmpl_contains("<>") || tmpl == "mistral"; - // [variant] space before + after response - bool space_around_response = tmpl_contains("' ' + eos_token"); + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2; // [variant] add BOS inside history - bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; // [variant] trim spaces from the input message - bool strip_message = tmpl_contains("content.strip()"); + bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; // construct the prompt bool is_inside_turn = true; // skip BOS at the beginning ss << "[INST] "; @@ -19232,12 +19858,11 @@ static int32_t llama_chat_apply_template_internal( } else if (role == "user") { ss << content << " [/INST]"; } else { - ss << (space_around_response ? " " : "") << content << (space_around_response ? 
" " : "") << ""; + ss << content << ""; is_inside_turn = false; } } - // llama2 templates seem to not care about "add_generation_prompt" - } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) { // Phi 3 for (auto message : chat) { std::string role(message->role); @@ -19246,7 +19871,16 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { + // Falcon 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) { // zephyr template for (auto message : chat) { ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; @@ -19254,7 +19888,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) { // mlabonne/AlphaMonarch-7B template (the is included inside history) for (auto message : chat) { std::string bos = (message == chat.front()) ? 
"" : ""; // skip BOS for first message @@ -19263,7 +19897,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "assistant\n"; } - } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) { // google/gemma-7b-it std::string system_prompt = ""; for (auto message : chat) { @@ -19285,7 +19919,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "model\n"; } - } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) { // OrionStarAI/Orion-14B-Chat std::string system_prompt = ""; for (auto message : chat) { @@ -19305,7 +19939,7 @@ static int32_t llama_chat_apply_template_internal( ss << message->content << ""; } } - } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) { // openchat/openchat-3.5-0106, for (auto message : chat) { std::string role(message->role); @@ -19319,13 +19953,13 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "GPT4 Correct Assistant:"; } - } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { // eachadea/vicuna-13b-1.1 (and Orca variant) for (auto message : chat) { std::string role(message->role); if (role == "system") { // Orca-Vicuna variant uses a system prefix - if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { + if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { ss << "SYSTEM: " << message->content << "\n"; } else { ss << message->content << "\n\n"; @@ -19339,7 +19973,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "ASSISTANT:"; } - } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) { // 
deepseek-ai/deepseek-coder-33b-instruct for (auto message : chat) { std::string role(message->role); @@ -19354,7 +19988,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "### Response:\n"; } - } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) { // CohereForAI/c4ai-command-r-plus for (auto message : chat) { std::string role(message->role); @@ -19369,7 +20003,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; } - } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) { // Llama 3 for (auto message : chat) { std::string role(message->role); @@ -19378,7 +20012,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -19388,7 +20022,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); @@ -19397,7 +20031,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { std::string role(message->role); @@ -19409,7 +20043,7 @@ static int32_t llama_chat_apply_template_internal( ss << trim(message->content); } } - } else if (tmpl == "deepseek2" || 
tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) { // DeepSeek-V2 for (auto message : chat) { std::string role(message->role); @@ -19424,6 +20058,96 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "Assistant:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) { + // DeepSeek-V3 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << LU8("<|User|>") << message->content; + } else if (role == "assistant") { + ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << LU8("<|Assistant|>"); + } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { + // this template requires the model to have "\n\n" as EOT token + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << "User: " << message->content << "\n\nAssistant:"; + } else { + ss << message->content << "\n\n"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { + // IBM Granite template + for (const auto & message : chat) { + std::string role(message->role); + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + if (role == "assistant_tool_call") { + ss << "<|tool_call|>"; + } + ss << message->content << "<|end_of_text|>\n"; 
+ } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { + // GigaChat template + bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; + + // Handle system message if present + if (has_system) { + ss << "" << chat[0]->content << "<|message_sep|>"; + } else { + ss << ""; + } + + // Process remaining messages + for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>" + << "available functions<|role_sep|>[]<|message_sep|>"; + } else if (role == "assistant") { + ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << "assistant<|role_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) { + // Megrez template + for (auto message : chat) { + std::string role(message->role); + ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>"; + } + + if (add_ass) { + ss << "<|role_start|>assistant<|role_end|>"; + } } else { // template not supported return -1; @@ -19443,15 +20167,15 @@ int32_t llama_chat_apply_template( std::string curr_tmpl(tmpl == nullptr ? 
"" : tmpl); if (tmpl == nullptr) { GGML_ASSERT(model != nullptr); - // load template from model - std::vector model_template(2048, 0); // longest known template is about 1200 bytes - std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - if (res < 0) { + + // load template from model, if available + const auto & it = model->gguf_kv.find("tokenizer.chat_template"); + if (it != model->gguf_kv.end() && it->second.size() > 0) { + curr_tmpl = it->second; + } + else { // worst case: there is no information about template, we will use chatml by default - curr_tmpl = "chatml"; // see llama_chat_apply_template_internal - } else { - curr_tmpl = std::string(model_template.data(), model_template.size()); + curr_tmpl = "chatml"; // see llama_chat_apply_template_internal } } @@ -19463,7 +20187,11 @@ int32_t llama_chat_apply_template( } std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); + llm_chat_template detected_tmpl = llama_chat_detect_template(curr_tmpl); + if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) { + return -1; + } + int32_t res = llama_chat_apply_template_internal(detected_tmpl, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } @@ -19473,6 +20201,14 @@ int32_t llama_chat_apply_template( return res; } +int32_t llama_chat_builtin_templates(const char ** output, size_t len) { + auto it = LLM_CHAT_TEMPLATES.begin(); + for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) { + output[i] = it->first.c_str(); + std::advance(it, 1); + } + return (int32_t) LLM_CHAT_TEMPLATES.size(); +} // // grammar // diff --git a/src/unicode.cpp b/src/unicode.cpp index 46650bff..cfffde0d 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -648,18 +648,25 @@ std::vector unicode_regex_split(const std::string & text, const std { "\\p{N}", codepoint_flags::NUMBER }, { 
"\\p{L}", codepoint_flags::LETTER }, { "\\p{P}", codepoint_flags::PUNCTUATION }, + { "\\p{M}", codepoint_flags::ACCENT_MARK }, + { "\\p{S}", codepoint_flags::SYMBOL }, }; static const std::map k_ucat_cpt = { { codepoint_flags::NUMBER, 0xD1 }, { codepoint_flags::LETTER, 0xD2 }, { codepoint_flags::PUNCTUATION, 0xD3 }, + { codepoint_flags::ACCENT_MARK, 0xD4 }, + { codepoint_flags::SYMBOL, 0xD5 }, + }; static const std::map k_ucat_map = { { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z - { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} + { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}i + { codepoint_flags::ACCENT_MARK, "" }, // no sub-128 codepoints + { codepoint_flags::SYMBOL, "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`| }; // compute collapsed codepoints only if needed by at least one regex