diff --git a/common/common.cpp b/common/common.cpp
index b4be2d31..30921e24 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1442,6 +1442,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.merge_qkv = true;
         return true;
     }
+    if (arg == "-muge" || arg == "--merge-up-gate-experts") {
+        params.merge_up_gate_exps = true;
+        return true;
+    }
     if (arg == "-khad" || arg == "--k-cache-hadamard") {
         params.k_cache_hadamard = true;
         return true;
     }
@@ -2148,6 +2152,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-no-gr, --no-graph-reuse", "disable graph reuse (default: %s)", !params.graph_reuse ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction", "experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
+    options.push_back({ "*", "-muge, --merge-up-gate-experts,","merge ffn_up/gate_exps (default: %d)", params.merge_up_gate_exps});
     options.push_back({ "*", "-khad, --k-cache-hadamard,", "Use Hadamard transform for K-cache (default: %d)", params.k_cache_hadamard});
     options.push_back({ "*", "-smf16, --split-mode-f16,", "Use f16 for data exchange between GPUs (default: %d)", params.split_mode_f16});
     options.push_back({ "*", "-smf32, --split-mode-f32,", "Use f32 for data exchange between GPUs (default: %d)", !params.split_mode_f16});
@@ -3088,6 +3093,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     mparams.use_thp = params.use_thp;
     mparams.validate_quants = params.validate_quants;
     mparams.merge_qkv = params.merge_qkv;
+    mparams.merge_up_gate_exps = params.merge_up_gate_exps;
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
@@ -4134,6 +4140,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "use_thp: %s # default: false\n", params.use_thp ? "true" : "false");
     fprintf(stream, "validate_quants: %s # default: false\n", params.validate_quants ? "true" : "false");
     fprintf(stream, "merge_qkv: %s # default: false\n", params.merge_qkv ? "true" : "false");
+    fprintf(stream, "merge_up_gate_exps: %s # default: false\n", params.merge_up_gate_exps ? "true" : "false");
     fprintf(stream, "max_extra_alloc: %d # default: 256\n", params.max_extra_alloc_MiB);
     fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ?
"true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); diff --git a/common/common.h b/common/common.h index 79d12773..9addabe3 100644 --- a/common/common.h +++ b/common/common.h @@ -287,6 +287,7 @@ struct gpt_params { bool validate_quants = false; // if true, check for NaNs while loading the model bool only_active_exps = true; // if true, offload only active experts (relevant only for hybrid CPU/GPU) bool merge_qkv = false; // if true, merge separate Q, K, V tensors into a single, contiguous tensor + bool merge_up_gate_exps= false; // if true, merge ffn_up_exps and ffn_gate_exps into a single, contiguous tensor bool k_cache_hadamard = false; // if true, use Hadamard transform for the K-cache (only makes sense with quantized cache) bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 2c450863..0672d1b2 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -268,6 +268,7 @@ struct cmd_params { bool use_thp = false; bool no_ooae = false; bool mqkv = false; + bool muge = false; bool rcache = false; output_formats output_format; output_formats output_format_stderr; @@ -293,7 +294,7 @@ static const cmd_params cmd_params_defaults = { /* mla_attn */ {3}, /* attn_max_batch */ {0}, /* ser */ {{-1,0.0f}}, - /* reuse */ {false}, + /* reuse */ {true}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -310,6 +311,7 @@ static const cmd_params cmd_params_defaults = { /* use_thp */ false, /* no_ooae */ false, /* mqkv */ false, + /* muge */ false, /* rcache */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, @@ -354,6 +356,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); printf(" -cuda, --cuda-params (default: %s)\n", cmd_params_defaults.cuda_params.c_str()); printf(" -mqkv, --merge-qkv (default: %s)\n", cmd_params_defaults.mqkv ? "1" : "0"); + printf(" -muge, --merge-up-gate-experts (default: %s)\n", cmd_params_defaults.muge ? "1" : "0"); printf(" -rcache, --rope-cache (default: %s)\n", cmd_params_defaults.rcache ? "1" : "0"); printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? 
"1" : "0"); printf(" -ot, --override-tensor pattern (default: none)\n"); @@ -789,6 +792,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.mqkv = std::stoi(argv[i]); + } else if (arg == "-muge" || arg == "--merge-up-gate-exps") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.muge = std::stoi(argv[i]); } else if (arg == "-rcache" || arg == "--rope-cache") { if (++i >= argc) { invalid_param = true; @@ -926,6 +935,7 @@ struct cmd_params_instance { bool use_thp = false; bool no_ooae = false; bool mqkv = false; + bool muge = false; bool rcache = false; const llama_model_tensor_buft_override* buft_overrides; @@ -943,6 +953,7 @@ struct cmd_params_instance { mparams.repack_tensors = repack; mparams.use_thp = use_thp; mparams.merge_qkv = mqkv; + mparams.merge_up_gate_exps = muge; mparams.tensor_buft_overrides = buft_overrides; mparams.mla = mla_attn; @@ -958,6 +969,7 @@ struct cmd_params_instance { use_mmap == other.use_mmap && repack == other.repack && mqkv == other.mqkv && + muge == other.muge && use_thp == other.use_thp && tensor_split == other.tensor_split; } @@ -1047,6 +1059,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_thp = */ params.use_thp, /* .no_ooae = */ params.no_ooae, /* .mqkv = */ params.mqkv, + /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .buft_overrides=*/ params.buft_overrides.data(), }; @@ -1088,6 +1101,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_thp = */ params.use_thp, /* .no_ooae = */ params.no_ooae, /* .mqkv = */ params.mqkv, + /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .buft_overrides=*/ params.buft_overrides.data(), }; @@ -1129,6 +1143,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_thp = */ params.use_thp, /* .no_ooae = */ params.no_ooae, /* .mqkv = */ params.mqkv, + /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .buft_overrides=*/ params.buft_overrides.data(), }; @@ -1170,6 +1185,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_thp = */ params.use_thp, /* .no_ooae = */ params.no_ooae, /* .mqkv = */ params.mqkv, + /* .muge = */ params.muge, /* .rcache = */ params.rcache, /* .buft_overrides=*/ params.buft_overrides.data(), }; @@ -1222,6 +1238,7 @@ struct test { bool use_thp = false; bool no_ooae = false; bool mqkv = false; + bool muge = false; bool rcache = false; std::string override_tensor; int n_prompt; @@ -1259,6 +1276,7 @@ struct test { embeddings = inst.embeddings; repack = inst.repack; mqkv = inst.mqkv; + muge = inst.muge; fmoe = inst.fmoe; ger = inst.ger; rcache = inst.rcache; @@ -1368,7 +1386,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse", - "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er", + "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "muge", "fused_moe", "grouped_er", "no_fused_up_gate", "use_thp", "no_ooae", "rcache", "cuda_params", "override_tensor", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -1392,7 +1410,7 @@ struct test { field == "gpu_blas" || field == "blas" || field == "sycl" || field == "no_kv_offload" || field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" || field == "fused_moe" || field == "grouped_er" || field == "no_fused_up_gate" || field == "no_ooae" || field == "mqkv" || - field == "rcache" || 
field == "reuse") { + field == "rcache" || field == "reuse" || field == "muge") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1435,7 +1453,7 @@ struct test { std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), - std::to_string(repack), std::to_string(mqkv), std::to_string(fmoe), std::to_string(ger), + std::to_string(repack), std::to_string(mqkv), std::to_string(muge), std::to_string(fmoe), std::to_string(ger), std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(rcache), cuda_params, override_tensor, std::to_string(n_prompt), std::to_string(n_gen), test_time, @@ -1621,6 +1639,9 @@ struct markdown_printer : public printer { if (field == "mqkv") { return 4; } + if (field == "muge") { + return 4; + } if (field == "use_thp") { return 3; } @@ -1688,6 +1709,9 @@ struct markdown_printer : public printer { if (field == "mqkv") { return "mqkv"; } + if (field == "muge") { + return "muge"; + } if (field == "use_thp") { return "thp"; } @@ -1791,6 +1815,9 @@ struct markdown_printer : public printer { if (params.mqkv != cmd_params_defaults.mqkv) { fields.emplace_back("mqkv"); } + if (params.muge != cmd_params_defaults.muge) { + fields.emplace_back("muge"); + } if (params.use_thp != cmd_params_defaults.use_thp) { fields.emplace_back("use_thp"); } diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 23300a8d..db8e3520 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1,4 +1,5 @@ // + // Copyright (C) 2023-2024 The ggml authors // Copyright (C) 2024 Iwan Kawrakow // MIT license @@ -2487,23 +2488,21 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 && ggml_is_quantized(src0_1->type) && - ggml_is_quantized(src0_2->type) && + (!src0_2 || ggml_is_quantized(src0_2->type)) && ggml_backend_buffer_is_cuda(src0_1->buffer) && - ggml_backend_buffer_is_cuda(src0_2->buffer) && + (!src0_2 || ggml_backend_buffer_is_cuda(src0_2->buffer)) && ggml_backend_buffer_is_cuda(src1->buffer) && ggml_backend_buffer_is_cuda(dst->buffer) && src1->type == GGML_TYPE_F32) { int device_id = ctx.device; ggml_backend_cuda_buffer_context * src0_1_ctx = (ggml_backend_cuda_buffer_context *) src0_1->buffer->context; - ggml_backend_cuda_buffer_context * src0_2_ctx = (ggml_backend_cuda_buffer_context *) src0_2->buffer->context; + ggml_backend_cuda_buffer_context * src0_2_ctx = src0_2 ? 
(ggml_backend_cuda_buffer_context *) src0_2->buffer->context : nullptr; ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; if (src0_1_ctx->device == device_id && - src0_2_ctx->device == device_id && + (!src0_2_ctx || src0_2_ctx->device == device_id) && src1_ctx->device == device_id && dst_ctx->device == device_id) { - //printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name, - // src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]); // Fast TG path const int64_t n_ids = ids->ne[0]; auto stream = ctx.stream(device_id, 0); @@ -2518,7 +2517,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten const int64_t src1_padded_col_size = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); ggml_cuda_pool_alloc src1_quantized(ctx.pool()); - if (ggml_is_quantized(src0_1->type) || ggml_is_quantized(src0_2->type)) { + if (ggml_is_quantized(src0_1->type) || (src0_2 && ggml_is_quantized(src0_2->type))) { GGML_ASSERT(src1->ne[0] % QK8_1 == 0); auto src_1_ddq_size = src1_padded_col_size*sizeof(block_q8_1)/QK8_1; local_src1.data = src1_quantized.alloc(src_1_ddq_size); @@ -2538,10 +2537,36 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id; auto unary_op = (ggml_unary_op)dst->op_params[0]; - ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, - dst->src[4], dst->src[5], - (const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), - (float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream); + if (src0_2) { + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, + dst->src[4], dst->src[5], + (const char *)src0_1->data, src0_2 ? 
(const char *)src0_2->data : nullptr, + (const float *)src1->data, src1_quantized.get(), + (float *)local_dst.data, 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream); + } else { + auto local_src0_1 = *src0_1; + local_src0_1.ne[1] /= 2; + auto local_src0_2 = local_src0_1; + local_src0_2.data = (char *)local_src0_1.data + local_src0_1.ne[1]*local_src0_1.nb[1]; + if (!dst->src[4]) { + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst, + nullptr, nullptr, + (const char *)local_src0_1.data, (const char *)local_src0_2.data, + (const float *)src1->data, src1_quantized.get(), + (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, stream); + } else { + GGML_ASSERT(!dst->src[5]); + auto local_bias_1 = *dst->src[4]; + local_bias_1.ne[0] /= 2; + auto local_bias_2 = local_bias_1; + local_bias_2.data = (char *)local_bias_1.data + local_bias_1.ne[0]*local_bias_1.nb[0]; + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, &local_src0_1, &local_src1, ids, &local_dst, + &local_bias_1, &local_bias_2, + (const char *)local_src0_1.data, (const char *)local_src0_2.data, + (const float *)src1->data, src1_quantized.get(), + (float *)local_dst.data, 0, local_src0_1.ne[1], 1, src1_padded_col_size, unary_op, stream); + } + } CUDA_CHECK(cudaGetLastError()); if (!fuse_next) return i; @@ -2608,7 +2633,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten // looks like it really depends just on the total number of experts. // TODO: verify with more models, or perhaps make the magic constant '32' to be defined via a compile time define. if (src1->ne[2] <= ctx.mmq_id_thresh*src0->ne[2] && - ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1 && + ggml_is_quantized(src0_1->type) && (!src0_2 || src0_1->type == src0_2->type) && src1->ne[1] == 1 && src1->ne[3] == 1 && ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { const int64_t ne_get_rows = ne12 * n_ids; @@ -2631,6 +2656,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten src0_1->type, ne10, src1->nb[1] / ts_src1, src1->nb[2] / ts_src1, src1->nb[2] / ts_src1, ne10_padded, ne11_flat, 1, 1, stream); + if (src0_2) { ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); @@ -2662,6 +2688,34 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); } + } else { + + ggml_cuda_pool_alloc dst_up_gate_contiguous(ctx.pool(), 2*sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + dst_row.ne[0] *= 2; + dst_row.nb[1] *= 2; + dst_row.nb[2] *= 2; + dst_row.nb[3] *= 2; + dst_row.data = dst_up_gate_contiguous.get(); + ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + if (dst->src[4]) { + GGML_ASSERT(!dst->src[5]); + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } + + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if 
(unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_up_gate_contiguous.get() + dst->ne[0], (const float *)dst_up_gate_contiguous.get(), + (float *)dst->data, ggml_nelements(dst), dst->ne[0], src0_1->ne[1], src0_1->ne[1], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst), dst->ne[0], + (const float *)dst_up_gate_contiguous.get(), (float *)dst->data); + } + } CUDA_CHECK(cudaGetLastError()); if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && @@ -2680,22 +2734,24 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten CUDA_CHECK(cudaStreamSynchronize(stream)); ggml_tensor src0_1_row = *src0_1; - ggml_tensor src0_2_row = *src0_2; + ggml_tensor src0_2_row; if (src0_2) src0_2_row = *src0_2; ggml_tensor src1_row = *src1; ggml_tensor final_dst; ggml_tensor final_src; char * src0_1_original = (char *) src0_1->data; - char * src0_2_original = (char *) src0_2->data; + char * src0_2_original = src0_2 ? (char *) src0_2->data : nullptr; char * src1_original = (char *) src1->data; char * dst_original = (char *) dst->data; src0_1_row.ne[2] = 1; src0_1_row.ne[3] = 1; src0_1_row.nb[3] = nb02; - src0_2_row.ne[2] = 1; - src0_2_row.ne[3] = 1; - src0_2_row.nb[3] = nb02; + if (src0_2) { + src0_2_row.ne[2] = 1; + src0_2_row.ne[3] = 1; + src0_2_row.nb[3] = nb02; + } src1_row.ne[1] = 1; src1_row.ne[2] = 1; @@ -2723,7 +2779,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten ggml_cuda_pool_alloc src1_quantized(ctx.pool()); bool use_quantized_src1 = false; int64_t src1_padded_num_cols = 0, src1_padded_row_size = 0, src1_quantized_size = 0; - if (ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type && src1->ne[1] == 1 && src1->ne[3] == 1) { + if (ggml_is_quantized(src0_1->type) && (!src0_2 || src0_1->type == src0_2->type) && src1->ne[1] == 1 && src1->ne[3] == 1) { if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) { src1_padded_num_cols = GGML_PAD(src1->ne[0], MATRIX_ROW_PADDING); src1_padded_row_size = src1_padded_num_cols/ggml_blck_size(GGML_TYPE_Q8_1)*ggml_type_size(GGML_TYPE_Q8_1); @@ -2736,8 +2792,14 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten if (!use_quantized_src1) { src1_contiguous.alloc(sizeof(float)*ggml_nelements(src1)); } - ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - ggml_cuda_pool_alloc dst_gate_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + ggml_cuda_pool_alloc dst_up_contiguous(ctx.pool()), dst_gate_contiguous(ctx.pool()); + if (src0_2) { + dst_up_contiguous.alloc(sizeof(float)*ggml_nelements(dst)); + dst_gate_contiguous.alloc(sizeof(float)*ggml_nelements(dst)); + } else { + dst_up_contiguous.alloc(2*sizeof(float)*ggml_nelements(dst)); + dst_gate_contiguous.alloc(sizeof(float)*ggml_nelements(dst)); + } ggml_cuda_pool_alloc final_dst_contiguous(ctx.pool()); if (fuse_down) { final_dst.data = final_dst_contiguous.alloc(ggml_nelements(next)); @@ -2780,20 +2842,26 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten } src0_1_row.data = src0_1_original + i02*nb02; - src0_2_row.data = src0_2_original + i02*nb02; + if (src0_2_original) src0_2_row.data = src0_2_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + auto nb1l = nb1; + if (!src0_2) { + nb1l = nb1*2; + dst_row.ne[0] = 
dst->ne[0] * 2; + } + src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = use_quantized_src1 ? src1_padded_row_size : nb11; src1_row.nb[2] = num_src1_rows*src1_row.nb[1]; src1_row.nb[3] = num_src1_rows*src1_row.nb[1]; dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; + dst_row.nb[1] = nb1l; + dst_row.nb[2] = num_src1_rows*nb1l; + dst_row.nb[3] = num_src1_rows*nb1l; dst_row.data = dst_up_contiguous.get(); if (use_quantized_src1) { @@ -2804,6 +2872,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten CUDA_CHECK(cudaGetLastError()); if (dst->src[4]) { + GGML_ASSERT(dst_row.ne[0] == dst->src[4]->ne[0]); dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); dim3 grid_dims(num_src1_rows); k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, @@ -2811,31 +2880,46 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten CUDA_CHECK(cudaGetLastError()); } - dst_row.data = dst_gate_contiguous.get(); - if (use_quantized_src1) { - ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get()); - } else { - ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0); - } - CUDA_CHECK(cudaGetLastError()); - - if (dst->src[5]) { - dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); - dim3 grid_dims(num_src1_rows); - k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, - (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); - CUDA_CHECK(cudaGetLastError()); - } - auto unary_op = (ggml_unary_op)dst->op_params[0]; - if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { - ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + if (src0_2) { + dst_row.data = dst_gate_contiguous.get(); + if (use_quantized_src1) { + ggml_cuda_mul_mat_q_id(ctx, &src0_2_row, &src1_row, nullptr, &dst_row, nullptr, src1_quantized.get()); + } else { + ggml_cuda_mul_mat(ctx, &src0_2_row, &src1_row, &dst_row, nullptr, 0); + } + CUDA_CHECK(cudaGetLastError()); + + if (dst->src[5]) { + dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u)); + dim3 grid_dims(num_src1_rows); + k_quick_add<<>>(dst_row.ne[0], (const float *)dst_row.data, + (const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data); + CUDA_CHECK(cudaGetLastError()); + } + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get()); + } } else { - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get()); + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_up_contiguous.get() + dst->ne[0], (const float *)dst_up_contiguous.get(), + (float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row)/2, dst->ne[0], src0_1->ne[1], src0_1->ne[1], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, 
(ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row)/2, dst->ne[0], + (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get()); + } + dst_row.data = dst_gate_contiguous.get(); + dst_row.ne[0] /= 2; + dst_row.nb[1] /= 2; + dst_row.nb[2] /= 2; + dst_row.nb[3] /= 2; } CUDA_CHECK(cudaGetLastError()); @@ -3603,7 +3687,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud auto src0_2 = node->src[1]; auto src1 = node->src[2]; if (src1->ne[1] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || src1->type != GGML_TYPE_F32 || - !ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) { + !ggml_is_quantized(src0_1->type) || (src0_2 && !ggml_is_quantized(src0_2->type))) { use_cuda_graph = false; } else { if (i < cgraph->n_nodes-1) { @@ -3967,8 +4051,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons bool is_fused_up_gate = op->op == GGML_OP_MOE_FUSED_UP_GATE || op->op == GGML_OP_FUSED_UP_GATE; struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = is_fused_up_gate ? op->src[2] : op->src[1]; - if (is_fused_up_gate && a->type != op->src[1]->type) { - printf("%s: returning false for GGML_OP_MOE_FUSED_UP_GATE because src0->type != src1->type\n", __func__); + if (is_fused_up_gate && op->src[1] && a->type != op->src[1]->type) { + fprintf(stderr, "%s: returning false for GGML_OP_MOE_FUSED_UP_GATE because src0->type != src1->type\n", __func__); return false; } //================================================================== diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 5e105474..2152210e 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -61,6 +61,18 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa dst[i] = x[i] * y[i] / (1.0f + expf(-x[i])); } +static __global__ void fused_mul_silu_f32(const float * x, float * dst, const int k, const int ne0) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + int row = i / ne0; + int j = i % ne0; + auto x_row = x + 2*row*ne0; + dst[i] = x_row[j] * x_row[j + ne0] / (1.0f + expf(-x_row[j + ne0])); +} + static __global__ void fused_mul_relu_f32(const float * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -70,6 +82,18 @@ static __global__ void fused_mul_relu_f32(const float * x, const float * y, floa dst[i] = fmaxf(x[i], 0) * y[i]; } +static __global__ void fused_mul_relu_f32(const float * x, float * dst, const int k, const int ne0) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + int row = i / ne0; + int j = i % ne0; + auto x_row = x + 2*row*ne0; + dst[i] = fmaxf(x_row[j + ne0], 0) * x_row[j]; +} + static __global__ void fused_mul_gelu_f32(const float * x, const float * y, float * dst, const int k) { constexpr float GELU_COEF_A = 0.044715f; constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -82,6 +106,21 @@ static __global__ void fused_mul_gelu_f32(const float * x, const float * y, floa dst[i] = 0.5f*xi*y[i]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); } +static __global__ void fused_mul_gelu_f32(const float * x, float * dst, const int k, const int ne0) { + constexpr float GELU_COEF_A = 0.044715f; + constexpr float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + int row = i / ne0; + int j = i % ne0; + auto x_row = x 
+ 2*row*ne0; + float xi = x_row[j + ne0]; + dst[i] = 0.5f*xi*x_row[j]*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -199,6 +238,21 @@ static void fused_mul_gelu_f32_cuda(const float * x, const float * y, float * ds fused_mul_gelu_f32<<>>(x, y, dst, k); } +static void fused_mul_silu_f32_cuda(const float * x, float * dst, const int k, const int ne0, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + fused_mul_silu_f32<<>>(x, dst, k, ne0); +} + +static void fused_mul_relu_f32_cuda(const float * x, float * dst, const int k, const int ne0, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + fused_mul_relu_f32<<>>(x, dst, k, ne0); +} + +static void fused_mul_gelu_f32_cuda(const float * x, float * dst, const int k, const int ne0, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + fused_mul_gelu_f32<<>>(x, dst, k, ne0); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<>>(x, dst, k); @@ -302,29 +356,33 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, } } +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements, int64_t ne0, const float * src0_d, float * dst_d) { + + cudaStream_t stream = ctx.stream(); + + switch (op) { + case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, dst_d, nelements, ne0, stream); break; + case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, dst_d, nelements, ne0, stream); break; + case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, dst_d, nelements, ne0, stream); break; + default: GGML_ASSERT(false); + } +} + void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - ggml_unary_op op = (ggml_unary_op)dst->op_params[0]; + GGML_ASSERT(ggml_is_contiguous(src0)); - ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data); - - //cudaStream_t stream = ctx.stream(); - - //const float * src0_d = (const float *)src0->data; - //const float * src1_d = (const float *)src1->data; - //float * dst_d = (float *)dst->data; - - //switch (op) { - // case GGML_UNARY_OP_SILU: fused_mul_silu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - // case GGML_UNARY_OP_RELU: fused_mul_relu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - // case GGML_UNARY_OP_GELU: fused_mul_gelu_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), stream); break; - // default: GGML_ASSERT(false); - //} + if (src1) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data); + } else { + GGML_ASSERT(src0->ne[0] == 2*dst->ne[0] && src0->ne[1] == dst->ne[1] && src0->ne[2] == dst->ne[2] && src0->ne[3] == dst->ne[3]); + ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), 
dst->ne[0], (const float *)src0->data, (float *)dst->data); + } } void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index f47a5cc7..42505344 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -89,4 +89,7 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, int64_t nelements, const float * x, const float * y, float * z); +void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op, + int64_t nelements,int64_t ne0, const float * x, float * z); + void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3d944bf6..c849e026 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7628,13 +7628,13 @@ struct ggml_tensor * ggml_moe_up_gate( struct ggml_tensor * b, struct ggml_tensor * ids, enum ggml_unary_op op) { - if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + if (as_gate && (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate))) { struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids); return ggml_fused_mul_unary(ctx, result_gate, result_up, op); } GGML_ASSERT(!ggml_is_transposed(as_up)); - GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(!as_gate || !ggml_is_transposed(as_gate)); GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) @@ -7646,11 +7646,11 @@ struct ggml_tensor * ggml_moe_up_gate( bool is_node = false; - if (as_up->grad || as_gate->grad || b->grad) { + if (as_up->grad || (as_gate && as_gate->grad) || b->grad) { is_node = true; } - const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + const int64_t ne[4] = { as_gate ? as_up->ne[1] : as_up->ne[1]/2, ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_MOE_FUSED_UP_GATE; @@ -7681,7 +7681,7 @@ struct ggml_tensor * ggml_moe_up_gate_ext( return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op); } - if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) { + if (as_gate && (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate))) { struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids); if (as_up_b) { result_up = ggml_add_id(ctx, result_up, as_up_b, ids); @@ -7694,7 +7694,7 @@ struct ggml_tensor * ggml_moe_up_gate_ext( } GGML_ASSERT(!ggml_is_transposed(as_up)); - GGML_ASSERT(!ggml_is_transposed(as_gate)); + GGML_ASSERT(!as_gate || !ggml_is_transposed(as_gate)); GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert) @@ -7705,10 +7705,10 @@ struct ggml_tensor * ggml_moe_up_gate_ext( GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]); - GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]); + GGML_ASSERT(!as_gate || as_gate->ne[1] == as_gate_b->ne[0]); bool is_node = false; - const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 }; + const int64_t ne[4] = { as_gate ? 
as_up->ne[1] : as_up->ne[1]/2, ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_MOE_FUSED_UP_GATE; @@ -16571,8 +16571,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - GGML_ASSERT(dst->src[0]->type == dst->src[1]->type); - GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst->src[1])); + GGML_ASSERT(!dst->src[1] || dst->src[0]->type == dst->src[1]->type); + GGML_ASSERT(!dst->src[1] || ggml_are_same_shape(dst->src[0], dst->src[1])); GGML_ASSERT(dst->type == GGML_TYPE_F32); const struct ggml_tensor * src1 = dst->src[2]; @@ -16604,7 +16604,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate( GGML_ASSERT(ne13 == 1); const size_t nb41 = up_b ? up_b->nb[1] : 0; - const size_t nb51 = up_b ? gate_b->nb[1] : 0; + const size_t nb51 = up_b && gate_b ? gate_b->nb[1] : 0; // row groups const int n_ids = ids->ne[0]; // n_expert_used @@ -16692,16 +16692,20 @@ static void ggml_compute_forward_mul_mat_id_up_gate( } const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02; - const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02; + const char * src0_2_cur = src0_2 ? (const char *) src0_2->data + cur_a*nb02 : src0_1_cur + nb02/2; const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL; const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL; + if (up_b_cur && !gate_b_cur) { + gate_b_cur = up_b_cur + nb41/2; + } const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows + const int64_t nr0 = src0_2 ? ne01 : ne01/2; // src0 rows const int64_t nr1 = cne1; // src1 rows - // + + //if (ith == 0) printf("Calling iqk_moe_fused_up_gate with nr0 = %d, nr1 = %d, ne00 = %d, ne11 = %d\n", (int)nr0, (int)nr1, (int)ne00, (int)ne11); if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0], type, src0_1_cur, src0_2_cur, nb01, vec_dot_type, (const char *)wdata, row_size, @@ -16709,27 +16713,6 @@ static void ggml_compute_forward_mul_mat_id_up_gate( (float *)dst->data, nb1, nb2, matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); -// if (nth%2 == 0) { -// const char * src0_d = ith%2 == 0 ? src0_1_cur : src0_2_cur; -// void * dst_d = ith%2 == 0 ? 
dst1->data : dst2->data; -// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, -// type, src0_d, nb01, -// vec_dot_type, (const char *)wdata, row_size, -// (float *)dst_d, nb1, nb2, -// matrix_rows + cur_a*ne12, ith/2, nth/2)) GGML_ABORT("fatal error"); -// -// } else { -// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, -// src0_1->type, (const char *)src0_1_cur, nb01, -// vec_dot_type, (const char *)wdata, row_size, -// (float *)dst1->data, nb1, nb2, -// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); -// if (!iqk_mul_mat_moe(nr0, nr1, ne00, ne11, -// src0_2->type, (const char *)src0_2_cur, nb01, -// vec_dot_type, (const char *)wdata, row_size, -// (float *)dst2->data, nb1, nb2, -// matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error"); -// } } #undef MMID_MATRIX_ROW @@ -25193,10 +25176,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa { cur = 0; const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * src2 = node->src[2]; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; - if (src2->type != vec_dot_type) { - cur += ggml_row_size(vec_dot_type, node->src[1]->ne[0]) * ggml_nrows(node->src[1]); + if (src1 && src1->type != vec_dot_type) { + cur += ggml_row_size(vec_dot_type, src2->ne[0]) * ggml_nrows(src2); } const int n_as = src0->ne[2]; cur += GGML_PAD(cur, sizeof(int64_t)); // align diff --git a/include/llama.h b/include/llama.h index 8364a616..9a52c51c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -392,6 +392,7 @@ extern "C" { bool use_thp; // use transparent huge pages (linux only) bool validate_quants; // if true, check for NaNs while loading the model bool merge_qkv; // if true, merge separate Q, K, V tensors into a single, contiguous tensor + bool merge_up_gate_exps; // if true, merge ffn_up_exps and ffn_gate_exps tensors into a single, contiguous tensor }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 7256bc7c..ea6d2dad 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -910,7 +910,8 @@ ggml_tensor * llm_build_context::llm_build_moe_ffn( bool scale_w, float w_scale, llm_expert_gating_func_type gating_op, - const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) { + const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input, + ggml_tensor * up_gate_exps, ggml_tensor * up_gate_exps_b) { auto input = cur; @@ -1025,6 +1026,19 @@ llm_expert_gating_func_type gating_op, bool can_use_fmoe = type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE; ggml_tensor * par; + if (can_use_fmoe && up_gate_exps) { + if (up_gate_exps_b) { + par = ggml_moe_up_gate_ext(ctx, up_gate_exps, nullptr, cur, selected_experts, up_gate_exps_b, nullptr, + type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : + type_op == LLM_FFN_GELU ? GGML_UNARY_OP_GELU : GGML_UNARY_OP_SWIGLU_OAI); + } else { + GGML_ASSERT(type_op != LLM_FFN_SWIGLU_OAI_MOE); + par = ggml_moe_up_gate(ctx, up_gate_exps, nullptr, cur, selected_experts, + type_op == LLM_FFN_SILU ? 
GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU); + } + } else { + GGML_ASSERT(!up_gate_exps && !up_gate_exps_b); + if (can_use_fmoe && lctx.cparams.fused_moe_up_gate && up_exps->type == gate_exps->type) { if (up_exps_b || gate_exps_b) { par = ggml_moe_up_gate_ext(ctx, up_exps, gate_exps, cur, selected_experts, up_exps_b, gate_exps_b, @@ -1069,6 +1083,7 @@ llm_expert_gating_func_type gating_op, GGML_ABORT("fatal error"); } + } } cb(par, "ffn_moe_gate_par", il); @@ -1130,7 +1145,8 @@ ggml_tensor * llm_build_context::llm_build_std_moe_ffn(ggml_context * ctx, llama float w_scale, llm_expert_gating_func_type gating_op, llm_ffn_op_type type_op_shexp, - const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) { + const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input, + ggml_tensor * up_gate_exps, ggml_tensor * up_gate_exps_b) { auto split_up_exps = (ggml_split_tensor_t *)up_exps->extra; auto split_gate_exps = (ggml_split_tensor_t *)gate_exps->extra; @@ -1164,7 +1180,7 @@ llm_expert_gating_func_type gating_op, the_exp_probs_b, n_expert, n_expert_used, type_op, norm_w, scale_w, w_scale, - gating_op, cb, il, graph, false); + gating_op, cb, il, graph, false, up_gate_exps, up_gate_exps_b); cb(routed_out, "routed_out", il); if (add_input) { routed_out = ggml_add(ctx, routed_out, input); @@ -4047,7 +4063,8 @@ ggml_cgraph * llm_build_context::build_qwen3moe() { n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0f, LLM_EXPERT_GATING_FUNC_SOFTMAX, - LLM_FFN_SILU, cb, il, gf, true); + LLM_FFN_SILU, cb, il, gf, true, + model.layers[il].ffn_up_gate_exps); //printf("%s: ffn = %s(%s)\n", __func__, cur->name, ggml_op_name(cur->op)); @@ -8410,7 +8427,8 @@ ggml_cgraph * llm_build_context::build_openai_moe() { n_expert, n_expert_used, LLM_FFN_SWIGLU_OAI_MOE, false, false, 0.0f, LLM_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT, - LLM_FFN_SWIGLU_OAI_MOE, cb, il, gf, true); + LLM_FFN_SWIGLU_OAI_MOE, cb, il, gf, true, + model.layers[il].ffn_up_gate_exps, model.layers[il].ffn_up_gate_exps_b); cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); diff --git a/src/llama-build-context.h b/src/llama-build-context.h index dda1246a..2cf36ece 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -354,7 +354,8 @@ struct llm_build_context { bool scale_w, float w_scale, llm_expert_gating_func_type gating_op, - const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false); + const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false, + ggml_tensor * up_gate_exps = nullptr, ggml_tensor * up_gate_exps_b = nullptr); static ggml_tensor * llm_build_moe_ffn(ggml_context * ctx, llama_context & lctx, ggml_tensor * cur, @@ -370,7 +371,8 @@ llm_expert_gating_func_type gating_op, bool scale_w, float w_scale, llm_expert_gating_func_type gating_op, - const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false) { + const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false, + ggml_tensor * up_gate_exps = nullptr, ggml_tensor * up_gate_exps_b = nullptr) { return llm_build_moe_ffn(ctx, lctx, cur, gate_inp, nullptr, up_exps, nullptr, @@ -379,7 +381,7 @@ llm_expert_gating_func_type gating_op, exp_probs_b, n_expert, n_expert_used, type_op, norm_w, scale_w, w_scale, - gating_op, cb, il, graph, add_input); + gating_op, cb, il, graph, add_input, up_gate_exps, up_gate_exps_b); } static ggml_tensor * llm_build_std_moe_ffn(ggml_context * ctx, llama_context & lctx, @@ -401,7 +403,8 @@ 
llm_expert_gating_func_type gating_op, float w_scale, llm_expert_gating_func_type gating_op, llm_ffn_op_type type_op_shexp, - const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input = false); + const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input = false, + ggml_tensor * up_gate_exps = nullptr, ggml_tensor * up_gate_exps_b = nullptr); static ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids); diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index d7147033..fcddc183 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -31,6 +31,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool merge_qkv(const LLM_TN & tn, int i, int bias, bool ignore_attn_scale = false); + bool merge_up_gate_exps(const LLM_TN & tn, int i, int bias); + bool create_tensors() override; bool create_llama_tensors(const LLM_TN & tn); @@ -141,6 +143,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { ggml_tensor * create_tensor(ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0, ggml_context ** actual_ctx = nullptr); + ggml_context * get_context_for_tensor(ggml_context * ctx, const std::string & name); + void create_default_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool norm_bias); void create_embd_output(const LLM_TN & tn, int n_embd, int n_vocab, bool has_norm = true, bool use_ctx_split = false); @@ -195,7 +199,10 @@ create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_mod buft_layer_count[model.buft_layer[i].buft_matrix]++; } - ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output + auto n_tensors = ml.n_tensors; + if (ml.merge_qkv) n_tensors += n_layer; + if (ml.merge_up_gate_exps) n_tensors += n_layer; + ctx_size = ggml_tensor_overhead()*(n_tensors + 1); // +1 for models where tok_embd is duplicated as output ctx_size += ggml_tensor_overhead()*n_layer*3; // for moe merged tensors if (model.splits.size() > 1) { @@ -288,9 +295,7 @@ static std::vector create_split(int nr, int granularity, const std::vector< return result; } -ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector & ne, - int flags, ggml_context ** actual_context) { - //auto requested_ctx = ctx; +ggml_context * create_tensors_helper::get_context_for_tensor(ggml_context * ctx, const std::string & name) { if (ml.tensor_buft_overrides) { for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { std::regex pattern(overrides->pattern); @@ -301,6 +306,12 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std } } } + return ctx; +} + +ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std::string & name, const std::vector & ne, + int flags, ggml_context ** actual_context) { + ctx = get_context_for_tensor(ctx, name); if (actual_context) *actual_context = ctx; auto tensor = ml.create_tensor(ctx, name, ne, flags); if (tensor && ctx == split_ctx) { @@ -1168,9 +1179,14 @@ bool create_tensors_helper::create_qwen3_moe_tensors(const LLM_TN & tn) { // MoE branch const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; - layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + bool merged = ml.merge_up_gate_exps && merge_up_gate_exps(tn, i, 0); + if (merged) { + use_mmap_buffer = false; + } else { + layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_gate_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + } layer.ffn_down_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); - layer.ffn_up_exps = create_tensor(ffn_ctx, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); } return use_mmap_buffer; } @@ -2572,9 +2588,18 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) { ggml_context *ctx_ffn_gate, *ctx_ffn_up, *ctx_ffn_down; layer.ffn_gate_inp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_gate); - layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0, &ctx_ffn_down); - layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up); + bool merged = ml.merge_up_gate_exps && merge_up_gate_exps(tn, i, 2); + use_mmap_buffer &= !merged; + if (merged) { + ctx_ffn_gate = ctx_ffn_up = ctx_split; + } else { + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), + { n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_up); + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), + { n_embd, n_ff_exp, n_expert}, 0, &ctx_ffn_gate); + } + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), + {n_ff_exp, n_embd, n_expert}, 0, &ctx_ffn_down); // bias layer.ffn_gate_inp_b = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); @@ -2582,15 +2607,17 @@ bool create_tensors_helper::create_openai_moe_tensors(const LLM_TN & tn) { auto ctx_gate_b = ctx_ffn_gate == ctx_split ? ctx_split : ctx_layer; auto ctx_down_b = ctx_ffn_down == ctx_split ? ctx_split : ctx_layer; auto ctx_up_b = ctx_ffn_up == ctx_split ? 
ctx_split : ctx_layer;
-        layer.ffn_gate_exps_b = create_tensor(ctx_gate_b, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_gate_b);
+        if (!merged) {
+            layer.ffn_up_exps_b = create_tensor(ctx_up_b, tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_up_b);
+            layer.ffn_gate_exps_b = create_tensor(ctx_gate_b, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_gate_b);
+        }
         layer.ffn_down_exps_b = create_tensor(ctx_down_b, tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0, &ctx_ffn_down_b);
-        layer.ffn_up_exps_b = create_tensor(ctx_up_b, tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0, &ctx_ffn_up_b);
-        if (ctx_ffn_gate_b != ctx_ffn_gate) {
+        if (!merged && ctx_ffn_gate_b != ctx_ffn_gate) {
             layer.ffn_gate_exps_b_dup = create_tensor(ctx_ffn_gate, tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, llama_model_loader::TENSOR_DUPLICATED);
         }
-        if (ctx_ffn_up_b != ctx_ffn_up) {
+        if (!merged && ctx_ffn_up_b != ctx_ffn_up) {
             layer.ffn_up_exps_b_dup = create_tensor(ctx_ffn_up, tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, llama_model_loader::TENSOR_DUPLICATED);
         }
@@ -2654,6 +2681,71 @@ bool create_tensors_helper::create_smollm3_tensors(const LLM_TN & tn) {
     return use_mmap_buffer;
 }
 
+bool create_tensors_helper::merge_up_gate_exps(const LLM_TN & tn, int i, int bias) {
+    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+    auto & layer = model.layers[i];
+
+    auto u_name = tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i);
+    auto g_name = tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i);
+    auto u_meta = ml.require_tensor_meta(u_name.c_str());
+    auto g_meta = ml.require_tensor_meta(g_name.c_str());
+
+    if (u_meta->type != g_meta->type || u_meta->ne[0] != g_meta->ne[0] || u_meta->ne[2] != g_meta->ne[2]) {
+        printf("%s: not merging because up/gate meta info is different\n", __func__);
+        return false;
+    }
+
+    auto u_ctx = get_context_for_tensor(ctx_split, u_name);
+    auto g_ctx = get_context_for_tensor(ctx_split, g_name);
+
+    if (u_ctx != g_ctx) {
+        printf("%s: not merging because of context\n", __func__);
+        return false;
+    }
+
+    if (bias && (u_ctx != ctx_split || g_ctx != ctx_split)) {
+        printf("%s: not merging because of context\n", __func__);
+        return false;
+    }
+
+    printf("%s: merging up/gate in layer %d\n", __func__, i);
+
+    layer.ffn_up_gate_exps = ggml_new_tensor_3d(u_ctx, u_meta->type, u_meta->ne[0], u_meta->ne[1] + g_meta->ne[1], u_meta->ne[2]);
+    snprintf(layer.ffn_up_gate_exps->name, GGML_MAX_NAME, "blk.%d.ffn_up_gate_exps.weight", i);
+    layer.ffn_up_exps = ml.create_tensor_as_view(u_ctx, layer.ffn_up_gate_exps, u_name.c_str(),
+            { u_meta->ne[0], u_meta->ne[1], u_meta->ne[2] }, 0);
+    layer.ffn_gate_exps = ml.create_tensor_as_view(u_ctx, layer.ffn_up_gate_exps, g_name.c_str(),
+            { g_meta->ne[0], g_meta->ne[1], g_meta->ne[2] }, ggml_nbytes(layer.ffn_up_exps) ); //u_meta->ne[1]*u_meta->nb[1] );
+
+    if (!bias) return true;
+
+    auto u_name_b = tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i);
+    auto g_name_b = tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i);
+    auto u_meta_b = ml.get_tensor_meta(u_name_b.c_str());
+    auto g_meta_b = ml.get_tensor_meta(g_name_b.c_str());
+    if (bias == 2) {
+        GGML_ASSERT(u_meta_b && g_meta_b);
+        GGML_ASSERT(u_meta_b->type == g_meta_b->type);
+        GGML_ASSERT(u_meta_b->ne[1] == g_meta_b->ne[1]);
+    } else {
+        GGML_ASSERT(!u_meta_b && !g_meta_b);
+        return true;
+    }
+
+    GGML_ASSERT(u_meta->ne[1] == u_meta_b->ne[0]);
+    GGML_ASSERT(g_meta->ne[1] == g_meta_b->ne[0]);
+
layer.ffn_up_gate_exps_b = ggml_new_tensor_2d(ctx_split, u_meta_b->type, u_meta_b->ne[0] + g_meta_b->ne[0], u_meta->ne[1]); + snprintf(layer.ffn_up_gate_exps_b->name, GGML_MAX_NAME, "blk.%d.ffn_up_gate_exps.bias", i); + layer.ffn_up_exps_b = ml.create_tensor_as_view(ctx_split, layer.ffn_up_gate_exps_b, u_name_b.c_str(), + { u_meta_b->ne[0], u_meta_b->ne[1] }, 0); + layer.ffn_gate_exps_b = ml.create_tensor_as_view(ctx_split, layer.ffn_up_gate_exps_b, g_name_b.c_str(), + { g_meta_b->ne[0], g_meta_b->ne[1] }, ggml_nbytes(layer.ffn_up_exps_b) ); //u_meta->nb[1]); + + return true; +} + bool create_tensors_helper::merge_qkv(const LLM_TN & tn, int i, int bias, bool ignore_attn_scale) { auto& hparams = model.hparams; const int64_t n_head = hparams.n_head(); @@ -2849,11 +2941,18 @@ bool create_tensors_helper::create_tensors() { bool use_mmap_buffer = true; if (ml.merge_qkv && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) { LLAMA_LOG_WARN("\n========================================================\n"); - LLAMA_LOG_WARN("merge_qkv is not compatible with split model 'graph'\n"); + LLAMA_LOG_WARN("merge_qkv is not compatible with split mode 'graph'\n"); LLAMA_LOG_WARN(" => turning off merge_qkv\n"); LLAMA_LOG_WARN("========================================================\n\n"); ml.merge_qkv = false; } + if (ml.merge_up_gate_exps && (model.split_mode == LLAMA_SPLIT_MODE_GRAPH || model.split_mode == LLAMA_SPLIT_MODE_ATTN)) { + LLAMA_LOG_WARN("\n========================================================\n"); + LLAMA_LOG_WARN("merge_up_gate_exps is not compatible with split mode 'graph'\n"); + LLAMA_LOG_WARN(" => turning off merge_up_gate_exps\n"); + LLAMA_LOG_WARN("========================================================\n\n"); + ml.merge_up_gate_exps = false; + } switch (model.arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index d7c68b33..4e200211 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -204,7 +204,7 @@ namespace GGUFMeta { } llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, - bool repack_tensors, bool use_thp, bool merge_qkv, + bool repack_tensors, bool use_thp, bool merge_qkv, bool merge_up_gate_exps, const llama_model_kv_override * param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { int trace = 0; @@ -497,6 +497,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, this->repack_tensors = repack_tensors; this->use_thp = use_thp; this->merge_qkv = merge_qkv; + this->merge_up_gate_exps = merge_up_gate_exps; } llama_model_loader::~llama_model_loader() { diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 366dea41..c59eaf4f 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -45,6 +45,7 @@ struct llama_model_loader { bool repack_tensors = false; bool use_thp = false; bool merge_qkv = false; + bool merge_up_gate_exps = false; llama_files files; llama_ftype ftype; @@ -79,7 +80,8 @@ struct llama_model_loader { std::string arch_name; LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, bool merge_qkv, + llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, bool repack_tensors, bool use_thp, + bool merge_qkv, bool merge_up_gate_exps, const llama_model_kv_override * 
param_overrides_p, const llama_model_tensor_buft_override * param_tensor_buft_overrides_p); diff --git a/src/llama-model.h b/src/llama-model.h index deb3b0c1..4667193e 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -236,6 +236,7 @@ struct llama_layer { struct ggml_tensor * ffn_gate_exps = nullptr; struct ggml_tensor * ffn_down_exps = nullptr; struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_up_gate_exps = nullptr; llama_split_tensor split_ffn_gate_inp; llama_split_tensor split_ffn_up_exps; @@ -247,6 +248,7 @@ struct llama_layer { struct ggml_tensor * ffn_gate_exps_b = nullptr; struct ggml_tensor * ffn_down_exps_b = nullptr; struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_up_gate_exps_b = nullptr; struct ggml_tensor * ffn_gate_exps_b_dup = nullptr; struct ggml_tensor * ffn_down_exps_b_dup = nullptr; struct ggml_tensor * ffn_up_exps_b_dup = nullptr; diff --git a/src/llama-quantize.cpp b/src/llama-quantize.cpp index c9938b38..927c3f31 100644 --- a/src/llama-quantize.cpp +++ b/src/llama-quantize.cpp @@ -1009,7 +1009,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s kv_overrides = v->data(); } llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, /* repack_tensors */ false, - /* use_thp */ false, /* merge_qkv */ false, kv_overrides, nullptr); + /* use_thp */ false, /* merge_qkv */ false, /* merge_up_gate_exps */ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model; diff --git a/src/llama.cpp b/src/llama.cpp index 10db6d58..053cc42e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1882,6 +1882,9 @@ static bool llm_load_tensors( } use_mmap_buffer = cth->create_tensors(); + if (!use_mmap_buffer) { + ml.use_mmap = false; + } ml.done_getting_tensors(); @@ -2104,7 +2107,8 @@ static bool llm_load_tensors( static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) { try { llama_model_loader ml(fname, params.use_mmap, params.check_tensors, - params.repack_tensors, params.use_thp, params.merge_qkv, params.kv_overrides, params.tensor_buft_overrides); + params.repack_tensors, params.use_thp, params.merge_qkv, params.merge_up_gate_exps, + params.kv_overrides, params.tensor_buft_overrides); model.hparams.vocab_only = params.vocab_only; @@ -4017,6 +4021,7 @@ struct llama_model_params llama_model_default_params() { /*.use_thp =*/ false, /*.validate_quants =*/ false, /*.merge_qkv =*/ false, + /*.merge_up_gate_exps =*/ false, }; #ifdef GGML_USE_METAL @@ -4286,6 +4291,80 @@ void llama_free_model(struct llama_model * model) { delete model; } +static void llama_repack_up_gate_exps(llama_context & lctx) { + auto & model = lctx.model; + bool needs_repack = false; + for (auto & l : model.layers) { + if (l.ffn_up_gate_exps && l.ffn_up_exps && l.ffn_gate_exps) { + needs_repack = true; break; + } + } + if (!needs_repack) return; + + std::vector aux_buffer_up, aux_buffer_gate, aux_buffer_up_gate; + for (int il = 0; il < int(model.layers.size()); ++il) { + auto & l = model.layers[il]; + if (l.ffn_up_gate_exps && l.ffn_up_exps && l.ffn_gate_exps) { + GGML_ASSERT(l.ffn_up_gate_exps->type == l.ffn_up_exps->type && l.ffn_up_gate_exps->type == l.ffn_gate_exps->type); + GGML_ASSERT(l.ffn_up_gate_exps->ne[0] == l.ffn_up_exps->ne[0] && l.ffn_up_gate_exps->ne[0] == l.ffn_gate_exps->ne[0]); + GGML_ASSERT(l.ffn_up_gate_exps->ne[2] == l.ffn_up_exps->ne[2] && l.ffn_up_gate_exps->ne[2] == l.ffn_gate_exps->ne[2]); + 
GGML_ASSERT(l.ffn_up_gate_exps->ne[1] == l.ffn_up_exps->ne[1] + l.ffn_gate_exps->ne[1]); + auto nbytes = ggml_nbytes(l.ffn_up_exps); + GGML_ASSERT(nbytes == ggml_nbytes(l.ffn_gate_exps)); + if (nbytes > aux_buffer_up.size()) { + aux_buffer_up.resize(nbytes); + } + if (nbytes > aux_buffer_gate.size()) { + aux_buffer_gate.resize(nbytes); + } + printf("%s: repacking up/gate experts weight in layer %d\n", __func__, il); + ggml_backend_tensor_get(l.ffn_up_exps, aux_buffer_up.data(), 0, nbytes); + ggml_backend_tensor_get(l.ffn_gate_exps, aux_buffer_gate.data(), 0, nbytes); + if (aux_buffer_up_gate.size() < 2*nbytes) { + aux_buffer_up_gate.resize(2*nbytes); + } + size_t offset_up_gate = 0; + size_t offset_up = 0; + auto expert_size = l.ffn_up_exps->ne[1]*l.ffn_up_exps->nb[1]; + for (int i2 = 0; i2 < (int)l.ffn_up_gate_exps->ne[2]; ++i2) { + std::memcpy(aux_buffer_up_gate.data() + offset_up_gate, aux_buffer_up.data() + offset_up, expert_size); + offset_up_gate += expert_size; + std::memcpy(aux_buffer_up_gate.data() + offset_up_gate, aux_buffer_gate.data() + offset_up, expert_size); + offset_up_gate += expert_size; + offset_up += expert_size; + } + ggml_backend_tensor_set(l.ffn_up_gate_exps, aux_buffer_up_gate.data(), 0, 2*expert_size*l.ffn_up_gate_exps->ne[2]); + if (l.ffn_up_gate_exps_b && l.ffn_up_exps_b && l.ffn_gate_exps_b) { + nbytes = ggml_nbytes(l.ffn_up_exps_b); + GGML_ASSERT(nbytes == ggml_nbytes(l.ffn_gate_exps_b)); + if (nbytes > aux_buffer_up.size()) { + aux_buffer_up.resize(nbytes); + } + if (nbytes > aux_buffer_gate.size()) { + aux_buffer_gate.resize(nbytes); + } + printf("%s: repacking up/gate experts bias in layer %d\n", __func__, il); + ggml_backend_tensor_get(l.ffn_up_exps_b, aux_buffer_up.data(), 0, nbytes); + ggml_backend_tensor_get(l.ffn_gate_exps_b, aux_buffer_gate.data(), 0, nbytes); + if (aux_buffer_up_gate.size() < 2*nbytes) { + aux_buffer_up_gate.resize(2*nbytes); + } + offset_up_gate = 0; + offset_up = 0; + expert_size = l.ffn_up_exps_b->nb[1]; + for (int i1 = 0; i1 < (int)l.ffn_up_gate_exps_b->ne[1]; ++i1) { + std::memcpy(aux_buffer_up_gate.data() + offset_up_gate, aux_buffer_up.data() + offset_up, expert_size); + offset_up_gate += expert_size; + std::memcpy(aux_buffer_up_gate.data() + offset_up_gate, aux_buffer_gate.data() + offset_up, expert_size); + offset_up_gate += expert_size; + offset_up += expert_size; + } + ggml_backend_tensor_set(l.ffn_up_gate_exps_b, aux_buffer_up_gate.data(), 0, 2*expert_size*l.ffn_up_gate_exps_b->ne[1]); + } + } + } +} + struct llama_context * llama_new_context_with_model( struct llama_model * model, struct llama_context_params params) { @@ -4748,6 +4827,8 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched)); } + llama_repack_up_gate_exps(*ctx); + // build worst-case graph int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch); int n_past = cparams.n_ctx - n_tokens;
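
Usage note (not part of the patch): the merged up/gate experts path is opt-in. On the command line it is exposed as -muge / --merge-up-gate-experts for tools that go through gpt_params, and as -muge <0|1> in llama-bench, as wired up in common/common.cpp and examples/llama-bench/llama-bench.cpp above. The sketch below shows the same switch from the public API side; it only relies on the new llama_model_params::merge_up_gate_exps field added in include/llama.h plus the fork's existing llama_load_model_from_file / llama_free_model entry points, and the model path is a placeholder.

#include "llama.h"

// Sketch: load a MoE model with ffn_up_exps/ffn_gate_exps merged per layer.
// Error handling, backend init and context creation are omitted.
static llama_model * load_with_merged_up_gate(const char * path /* placeholder */) {
    llama_model_params mparams = llama_model_default_params();
    mparams.merge_up_gate_exps = true;   // field introduced by this change
    return llama_load_model_from_file(path, mparams);
}
// ... use the model as usual, then release it with llama_free_model(model);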