From 519405dc979de59d651c4f1f3e42ec4d11606fd1 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sat, 27 Dec 2025 08:18:06 +0100
Subject: [PATCH] Async compute graph evaluation (2 or more GPUs) (#1089)

* WIP: absorb adding input into std_attn and std_ffn
* WIP: NCCL infra
* WIP: add reduce and fake_cpy ops
* WIP
* WIP: graph appears to work, layer is broken
* WIP: Qwen3-MoE works with graph, layer still broken
* WIP: GLM-4.5 graph works
* WIP: fix sm layer (dense)
* WIP: fix sm layer (MoE)
* WIP: fast PP with bespoke 4-GPU NCCL

I guess I'm not using NCCL the right way, as PP is very low with a single
communicator group for 3 or more GPUs. But if I create 4 communicator groups
for pairs of GPUs (0,1, 2,3, 0,2, 1,3) and use that, PP is fast: I'm hitting
1500 t/s for L3-70B on the 4x3090 system, which is ~20% better than the
previous sm graph without NCCL. But that cannot be the solution (I cannot be
creating pairwise communicators and associated logic for every possible
number of GPUs).

* WIP: Cohere2
* Explicitly set device
* Bespoke 3-GPU case
* WIP
* Do not repeat get_rows multiple times
* Fix 3 GPUs
* OK, let's leave it in
* Simple async
* This sync seems enough
* Only do async for 4 or more backends

With 2 GPUs (so, 3 backends) not using async is slightly faster

* Scheduler changes
* Use OpenMP if available

Surprisingly (at least to me), this is quite a bit faster than std::thread
and std::barrier. GLM-4.5-AIR with 4 GPUs is now at 105 t/s at zero context!

* Do not use OpenMP if there are tensor overrides
* Set omp max active levels
* Be more careful with having set the device before using a stream
* Command line option to turn on async. Set to false by default for now

---------

Co-authored-by: Iwan Kawrakow
---
 CMakeLists.txt              |   2 +-
 common/common.cpp           |   7 +
 common/common.h             |   1 +
 ggml/include/ggml-backend.h |   2 +-
 ggml/src/ggml-backend.cpp   | 424 ++++++++++++++++++++++++++----------
 include/llama.h             |   1 +
 src/llama-build-context.cpp |   6 -
 src/llama-cparams.h         |   1 +
 src/llama.cpp               |   7 +-
 tests/CMakeLists.txt        |   2 +-
 10 files changed, 321 insertions(+), 132 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bbb2c991..acf17f9b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,7 +6,7 @@ include(CheckIncludeFileCXX)
 set(CMAKE_WARN_UNUSED_CLI YES)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED true)
 
 set(CMAKE_CUDA_USE_RESPONSE_FILE_FOR_INCLUDES 0)

diff --git a/common/common.cpp b/common/common.cpp
index 626da8d6..0d9d12c8 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1436,6 +1436,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.split_mode_graph_scheduling = true;
         return true;
     }
+    if (arg == "-sas" || arg == "--scheduler-async") {
+        params.scheduler_async = true;
+        return true;
+    }
     if (arg == "-smf16" || arg == "--split-mode-f16") {
         params.split_mode_f16 = true;
         return true;
@@ -2133,6 +2137,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-smf16, --split-mode-f16,", "Use f16 for data exchange between GPUs (default: %d)", params.split_mode_f16});
     options.push_back({ "*", "-smf32, --split-mode-f32,", "Use f32 for data exchange between GPUs (default: %d)", !params.split_mode_f16});
     options.push_back({ "*", "-smgs, --split-mode-graph-scheduling,", "Force Split Mode Graph Scheduling (default: %d)", params.split_mode_graph_scheduling});
+    options.push_back({ "*", "-sas, --scheduler-async,",
"Async evaluation of compute graphs: %d)", params.scheduler_async}); options.push_back({ "*", "-vq, --validate-quants", "validate quantized data while loading the model (default: %d)", params.validate_quants}); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -3167,6 +3172,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.k_cache_hadamard = params.k_cache_hadamard; cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; cparams.split_mode_f16 = params.split_mode_f16; + cparams.scheduler_async = params.scheduler_async; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.only_active_experts = params.only_active_exps; @@ -4150,6 +4156,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "k_cache_hadamard: %s # default: false\n", params.k_cache_hadamard ? "true" : "false"); fprintf(stream, "split_mode_graph_scheduling: %s # default: false\n", params.split_mode_graph_scheduling ? "true" : "false"); fprintf(stream, "split_mode_f16: %s # default: true\n", params.split_mode_f16 ? "true" : "false"); + fprintf(stream, "scheduler_async: %s # default: false\n", params.scheduler_async ? "true" : "false"); fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 8fe2287b..a5b231fe 100644 --- a/common/common.h +++ b/common/common.h @@ -290,6 +290,7 @@ struct gpt_params { bool k_cache_hadamard = false; // if true, use Hadamard transform for the K-cache (only makes sense with quantized cache) bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops + bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 82f05092..e75606dc 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -211,7 +211,7 @@ extern "C" { // enable or disable op offload for a given op GGML_API void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op op, bool on_or_off); GGML_API void ggml_backend_sched_set_only_active_experts(ggml_backend_sched_t sched, bool on_or_off); - GGML_API void ggml_backend_sched_set_split_mode_graph(ggml_backend_sched_t sched, bool on_or_off); + GGML_API void ggml_backend_sched_set_split_mode_graph(ggml_backend_sched_t sched, bool on_or_off, bool async); GGML_API void ggml_backend_sched_set_max_extra_alloc(ggml_backend_sched_t sched, int extra_alloc_MiB); // diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 65739cf3..345eddd2 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -14,6 +14,11 @@ #include #include #include +#include +#include +#ifdef GGML_USE_OPENMP +#include +#endif #define IK_PRINT_TIMING 0 @@ -1169,9 +1174,17 @@ struct ggml_backend_sched { uint32_t op_offload[(GGML_OP_COUNT + 31)/32]; + std::vector workers; + std::vector statuses; + std::vector> backend_splits; + std::array 
needs_sync; + std::array own_cpy; + bool only_active_experts; bool split_mode_graph; + bool is_async = false; bool debug; + bool has_reduce = false; }; void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op op, bool on_or_off) { @@ -1196,9 +1209,10 @@ void ggml_backend_sched_set_only_active_experts(ggml_backend_sched_t sched, bool sched->only_active_experts = on_or_off; } -void ggml_backend_sched_set_split_mode_graph(ggml_backend_sched_t sched, bool on_or_off) { +void ggml_backend_sched_set_split_mode_graph(ggml_backend_sched_t sched, bool on_or_off, bool async) { if (!sched) return; sched->split_mode_graph = on_or_off; + sched->is_async = async; } void ggml_backend_sched_set_max_extra_alloc(ggml_backend_sched_t sched, int extra_alloc_MiB) { @@ -1393,6 +1407,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->n_splits = 0; sched->n_graph_inputs = 0; sched->is_reset = false; + sched->has_reduce = false; struct ggml_init_params params = { /* .mem_size = */ sched->context_buffer_size, @@ -1697,6 +1712,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg // check if we should start a new split based on the sources of the current node bool need_new_split = false; + if (node->op == GGML_OP_REDUCE) { + sched->has_reduce = true; + } if ((node->op == GGML_OP_ADD && node->op_params[0] == 0xff) || node->op == GGML_OP_REDUCE || node->op == GGML_OP_FAKE_CPY || @@ -2083,89 +2101,206 @@ static void ggml_backend_sched_copy_inputs(ggml_backend_sched_t sched, ggml_back } } +static ggml_status ggml_backend_sched_eval(ggml_backend_sched_t sched, ggml_backend_t split_backend, ggml_backend_sched_split * split) { + if (!sched->callback_eval) { +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(.1.): %d us\n", __func__, (int)(tim2-tim1)); +#endif + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + } else { + // similar to ggml_backend_compare_graph_backend + for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { + struct ggml_tensor * t = split->graph.nodes[j0]; + + // check if the user needs data from this node + bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); + + int j1 = j0; + + // determine the range [j0, j1] of nodes that can be computed together + while (!need && j1 < split->graph.n_nodes - 1) { + t = split->graph.nodes[++j1]; + need = sched->callback_eval(t, true, sched->callback_eval_user_data); + } + + struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); + +#if IK_PRINT_TIMING + int64_t tim2 = ggml_time_us(); + printf("%s(.2.): %d us\n", __func__, (int)(tim2-tim1)); +#endif + + enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); + if (ec != GGML_STATUS_SUCCESS) { + return ec; + } + + // TODO: pass backend to the callback, then the user can decide if they want to synchronize + ggml_backend_synchronize(split_backend); + + if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { + break; + } + + j0 = j1; + } + } + return GGML_STATUS_SUCCESS; +} + static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { - std::array needs_sync{{true}}; - std::array own_cpy{{false}}; + for (auto & item : sched->needs_sync) item = true; - if (sched->split_mode_graph) { - auto tensor_size = [] (const ggml_tensor * t) { - auto nbytes = ggml_nbytes(t); - nbytes = 256*((nbytes + 255)/256); - return nbytes; - }; - 
//auto tim1 = std::chrono::steady_clock::now(); - std::vector> backend_splits(sched->n_backends); - for (int i = 0; i < sched->n_splits; i++) { - backend_splits[sched->splits[i].backend_id].push_back(&sched->splits[i]); + if (sched->is_async && sched->n_backends > 2 && sched->split_mode_graph && sched->has_reduce) { + + for (auto & s : sched->statuses) s = GGML_STATUS_SUCCESS; + + bool work_done = false; +#ifdef GGML_USE_OPENMP + if (int nlevels = omp_get_max_active_levels(); nlevels < 2) { + omp_set_max_active_levels(nlevels+1); + //printf("%s: Setting omp max active levels to 2\n", __func__); } - for (int backend_id = 0; backend_id < sched->n_backends; ++backend_id) { - if (ggml_backend_is_cpu(ggml_backend_sched_get_backend(sched, backend_id))) continue; - if (backend_splits[backend_id].empty()) continue; - size_t input_size = 0; - size_t max_input_size = 0; - int last_split = 0; - bool can_alloc = true; - for (int i = 0; i < int(backend_splits[backend_id].size()); ++i) { - auto split = backend_splits[backend_id][i]; - if (split->n_inputs < 1) continue; - size_t this_size = 0; - for (int j = 0; j < split->n_inputs; ++j) { - if (!ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { - this_size += tensor_size(split->inputs[j]); - } - } - if (input_size + this_size > sched->max_extra_alloc) { - if (i - last_split < 3) { - can_alloc = false; + bool has_cpu_work = false; + for (int i = 0; i < sched->n_backends; ++i) { + if (!sched->backend_splits[i].empty()) { + auto split = sched->backend_splits[i].front(); + if (ggml_backend_is_cpu(sched->backends[split->backend_id])) { + //printf("CPU backend %d has %d splits\n", split->backend_id, (int)sched->backend_splits[i].size()); + if (sched->backend_splits[i].size() > 1) { + has_cpu_work = true; break; } - max_input_size = std::max(max_input_size, input_size); - input_size = 0; - last_split = i - 1; } - input_size += this_size; } - max_input_size = std::max(max_input_size, input_size); - if (!can_alloc || !max_input_size) continue; - if (sched->input_memory_bufs[backend_id] && sched->input_memory_bufs[backend_id]->size < max_input_size) { - ggml_backend_buffer_free(sched->input_memory_bufs[backend_id]); - sched->input_memory_bufs[backend_id] = nullptr; - } - if (!sched->input_memory_bufs[backend_id]) { - sched->input_memory_bufs[backend_id] = ggml_backend_alloc_buffer(sched->backends[backend_id], max_input_size); - } - auto ptr = (char *)ggml_backend_buffer_get_base(sched->input_memory_bufs[backend_id]); - input_size = 0; - for (int i = 0; i < int(backend_splits[backend_id].size()); ++i) { - auto split = backend_splits[backend_id][i]; - size_t this_size = 0; - for (int j = 0; j < split->n_inputs; ++j) { - if (!ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { - this_size += tensor_size(split->inputs[j]); - } + } + if (!has_cpu_work) { + #pragma omp parallel num_threads(sched->n_backends) + { + + int ith = omp_get_thread_num(); + + struct ggml_backend_sched_split * splits = sched->splits; + + std::vector ids; + std::vector unique_ids; + ggml_tensor * last_ids_tensor = nullptr; + + for (int i = 0; i < sched->n_splits; i++) { +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif + struct ggml_backend_sched_split * split = &splits[i]; + int split_backend_id = split->backend_id; + ggml_backend_t split_backend = sched->backends[split_backend_id]; + + bool needs_barrier = split->n_inputs > 0 || split->graph.nodes[0]->op == GGML_OP_REDUCE; + + if (needs_barrier) { + #pragma omp barrier } - if (input_size + this_size > max_input_size) { - 
ptr = (char *)ggml_backend_buffer_get_base(sched->input_memory_bufs[backend_id]); - input_size = 0; - } - for (int j = 0; j < split->n_inputs; ++j) { - if (ggml_backend_buffer_is_host(split->inputs[j]->buffer)) continue; - auto input_cpy = tensor_copy(split->inputs[j], backend_id, sched->cur_copy); - for (int k = 0; k < split->graph.n_nodes; ++k) { - auto node = split->graph.nodes[k]; - for (int l = 0; l < GGML_MAX_SRC; ++l) { - if (node->src[l] && node->src[l]->data == input_cpy->data) node->src[l]->data = ptr; + + if (ith == split_backend_id) { + // copy the input tensors to the split backend + ggml_backend_sched_copy_inputs(sched, split, sched->needs_sync, ids, unique_ids, last_ids_tensor); + + if (split->n_inputs > 0 && !sched->own_cpy[split_backend_id]) { + sched->needs_sync[split_backend_id] = true; + } else { + for (int j = 0; j < split->n_inputs; ++j) { + if (ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { + sched->needs_sync[split_backend_id] = true; + } } } - input_cpy->data = ptr; - ptr += tensor_size(split->inputs[j]); + sched->statuses[ith] = ggml_backend_sched_eval(sched, split_backend, split); + } + + if (split->graph.nodes[0]->op == GGML_OP_REDUCE) { + #pragma omp barrier + } + + // record the event of this copy + if (split->n_inputs > 0) { + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]); + } } - input_size += this_size; } - needs_sync[backend_id] = false; - own_cpy[backend_id] = true; } + work_done = true; + } +#endif + if (!work_done) { + std::barrier barrier(sched->n_backends, [] () {}); + auto compute = [sched, &barrier] (int ith) { + + struct ggml_backend_sched_split * splits = sched->splits; + + std::vector ids; + std::vector unique_ids; + ggml_tensor * last_ids_tensor = nullptr; + + for (int i = 0; i < sched->n_splits; i++) { +#if IK_PRINT_TIMING + int64_t tim1 = ggml_time_us(); +#endif + struct ggml_backend_sched_split * split = &splits[i]; + int split_backend_id = split->backend_id; + ggml_backend_t split_backend = sched->backends[split_backend_id]; + + bool needs_barrier = split->n_inputs > 0 || split->graph.nodes[0]->op == GGML_OP_REDUCE; + + if (needs_barrier) { + barrier.arrive_and_wait(); + } + + if (ith == split_backend_id) { + // copy the input tensors to the split backend + ggml_backend_sched_copy_inputs(sched, split, sched->needs_sync, ids, unique_ids, last_ids_tensor); + + if (split->n_inputs > 0 && !sched->own_cpy[split_backend_id]) { + sched->needs_sync[split_backend_id] = true; + } else { + for (int j = 0; j < split->n_inputs; ++j) { + if (ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { + sched->needs_sync[split_backend_id] = true; + } + } + } + sched->statuses[ith] = ggml_backend_sched_eval(sched, split_backend, split); + } + + if (split->graph.nodes[0]->op == GGML_OP_REDUCE) { + barrier.arrive_and_wait(); + } + //if (needs_barrier) { + // barrier.arrive_and_wait(); + //} + + // record the event of this copy + if (split->n_inputs > 0) { + if (sched->events[split_backend_id][sched->cur_copy] != NULL) { + ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]); + } + } + } + }; + + for (int i = 0; i < sched->n_backends; ++i) sched->workers.emplace_back(compute, i); + for (auto & w : sched->workers) w.join(); + sched->workers.clear(); + } + for (auto status : sched->statuses) { + if (status != GGML_STATUS_SUCCESS) return status; + } + return GGML_STATUS_SUCCESS; + } struct ggml_backend_sched_split * splits = sched->splits; 
@@ -2183,63 +2318,20 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s ggml_backend_t split_backend = sched->backends[split_backend_id]; // copy the input tensors to the split backend - ggml_backend_sched_copy_inputs(sched, split, needs_sync, ids, unique_ids, last_ids_tensor); + ggml_backend_sched_copy_inputs(sched, split, sched->needs_sync, ids, unique_ids, last_ids_tensor); - if (split->n_inputs > 0 && !own_cpy[split_backend_id]) { - needs_sync[split_backend_id] = true; + if (split->n_inputs > 0 && !sched->own_cpy[split_backend_id]) { + sched->needs_sync[split_backend_id] = true; } else { for (int j = 0; j < split->n_inputs; ++j) { if (ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { - needs_sync[split_backend_id] = true; + sched->needs_sync[split_backend_id] = true; } } } - if (!sched->callback_eval) { -#if IK_PRINT_TIMING - int64_t tim2 = ggml_time_us(); - printf("%s(.1.): %d us\n", __func__, (int)(tim2-tim1)); -#endif - enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); - if (ec != GGML_STATUS_SUCCESS) { - return ec; - } - } else { - // similar to ggml_backend_compare_graph_backend - for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { - struct ggml_tensor * t = split->graph.nodes[j0]; - - // check if the user needs data from this node - bool need = sched->callback_eval(t, true, sched->callback_eval_user_data); - - int j1 = j0; - - // determine the range [j0, j1] of nodes that can be computed together - while (!need && j1 < split->graph.n_nodes - 1) { - t = split->graph.nodes[++j1]; - need = sched->callback_eval(t, true, sched->callback_eval_user_data); - } - - struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1); - -#if IK_PRINT_TIMING - int64_t tim2 = ggml_time_us(); - printf("%s(.2.): %d us\n", __func__, (int)(tim2-tim1)); -#endif - - enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv); - if (ec != GGML_STATUS_SUCCESS) { - return ec; - } - - // TODO: pass backend to the callback, then the user can decide if they want to synchronize - ggml_backend_synchronize(split_backend); - - if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) { - break; - } - - j0 = j1; - } + auto ec = ggml_backend_sched_eval(sched, split_backend, split); + if (ec != GGML_STATUS_SUCCESS) { + return ec; } // record the event of this copy @@ -2305,6 +2397,10 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->workers.reserve(sched->n_backends); + sched->statuses.resize(sched->n_backends, GGML_STATUS_SUCCESS); + sched->backend_splits.resize(sched->n_backends); + ggml_backend_sched_reset(sched); return sched; @@ -2366,15 +2462,101 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * return true; } +static void ggml_sched_prepare_graph(ggml_backend_sched_t sched) { + + for (auto & item : sched->own_cpy ) item = false; + for (auto & item : sched->needs_sync) item = true; + + if (sched->split_mode_graph) { + auto tensor_size = [] (const ggml_tensor * t) { + auto nbytes = ggml_nbytes(t); + nbytes = 256*((nbytes + 255)/256); + return nbytes; + }; + //auto tim1 = std::chrono::steady_clock::now(); + for (auto & b : sched->backend_splits) b.clear(); + for (int i = 0; i < sched->n_splits; i++) { + sched->backend_splits[sched->splits[i].backend_id].push_back(&sched->splits[i]); + } + for (int backend_id = 0; backend_id < sched->n_backends; ++backend_id) { + if 
(ggml_backend_is_cpu(ggml_backend_sched_get_backend(sched, backend_id))) continue; + if (sched->backend_splits[backend_id].empty()) continue; + size_t input_size = 0; + size_t max_input_size = 0; + int last_split = 0; + bool can_alloc = true; + for (int i = 0; i < int(sched->backend_splits[backend_id].size()); ++i) { + auto split = sched->backend_splits[backend_id][i]; + if (split->n_inputs < 1) continue; + size_t this_size = 0; + for (int j = 0; j < split->n_inputs; ++j) { + if (!ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { + this_size += tensor_size(split->inputs[j]); + } + } + if (input_size + this_size > sched->max_extra_alloc) { + if (i - last_split < 3) { + can_alloc = false; + break; + } + max_input_size = std::max(max_input_size, input_size); + input_size = 0; + last_split = i - 1; + } + input_size += this_size; + } + max_input_size = std::max(max_input_size, input_size); + if (!can_alloc || !max_input_size) continue; + if (sched->input_memory_bufs[backend_id] && sched->input_memory_bufs[backend_id]->size < max_input_size) { + ggml_backend_buffer_free(sched->input_memory_bufs[backend_id]); + sched->input_memory_bufs[backend_id] = nullptr; + } + if (!sched->input_memory_bufs[backend_id]) { + sched->input_memory_bufs[backend_id] = ggml_backend_alloc_buffer(sched->backends[backend_id], max_input_size); + } + auto ptr = (char *)ggml_backend_buffer_get_base(sched->input_memory_bufs[backend_id]); + input_size = 0; + for (int i = 0; i < int(sched->backend_splits[backend_id].size()); ++i) { + auto split = sched->backend_splits[backend_id][i]; + size_t this_size = 0; + for (int j = 0; j < split->n_inputs; ++j) { + if (!ggml_backend_buffer_is_host(split->inputs[j]->buffer)) { + this_size += tensor_size(split->inputs[j]); + } + } + if (input_size + this_size > max_input_size) { + ptr = (char *)ggml_backend_buffer_get_base(sched->input_memory_bufs[backend_id]); + input_size = 0; + } + for (int j = 0; j < split->n_inputs; ++j) { + if (ggml_backend_buffer_is_host(split->inputs[j]->buffer)) continue; + auto input_cpy = tensor_copy(split->inputs[j], backend_id, sched->cur_copy); + for (int k = 0; k < split->graph.n_nodes; ++k) { + auto node = split->graph.nodes[k]; + for (int l = 0; l < GGML_MAX_SRC; ++l) { + if (node->src[l] && node->src[l]->data == input_cpy->data) node->src[l]->data = ptr; + } + } + input_cpy->data = ptr; + ptr += tensor_size(split->inputs[j]); + } + input_size += this_size; + } + sched->needs_sync[backend_id] = false; + sched->own_cpy[backend_id] = true; + } + } +} + bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); ggml_backend_sched_split_graph(sched, graph); - if (!ggml_backend_sched_alloc_splits(sched)) { return false; } + ggml_sched_prepare_graph(sched); sched->is_alloc = true; diff --git a/include/llama.h b/include/llama.h index a33ce538..16558c5b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -445,6 +445,7 @@ extern "C" { bool k_cache_hadamard; // if true, apply Hadamard transfrom to K-cache bool split_mode_graph_scheduling; // if true, force split mode graph scheduling bool split_mode_f16; // if true, cast intermediate results to f16 before copying to other GPUs + bool scheduler_async; // if true, with split mode "graph" graph evaluation will be done using multiple threads // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 
cb7edac3..7173dd42 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -678,9 +678,6 @@ ggml_tensor * llm_build_context::llm_build_ffn( auto norm = (ggml_split_tensor_t *)ffn_norm->extra; GGML_ASSERT(norm->splits[id]); if (is_norm) { - //cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il); - //GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM); - //cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff; cur = ggml_fused_norm(ctx, cur, norm->splits[id], lctx.model.hparams.f_norm_eps); } else { cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il); @@ -9389,9 +9386,6 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens auto cur = get_input_tensor_sm_graph(input, id); if (attn_norm) { if (is_norm) { - //cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il); - //GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM); - //cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff; cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps); } else { cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index b639e818..3735a474 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -42,6 +42,7 @@ struct llama_cparams { bool k_cache_hadamard; bool split_mode_graph_scheduling; bool split_mode_f16; + bool scheduler_async; int min_experts; float thresh_experts; diff --git a/src/llama.cpp b/src/llama.cpp index 94242f0c..468c5c40 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4056,6 +4056,7 @@ struct llama_context_params llama_context_default_params() { /*.k_cache_hadamard =*/ false, /*.split_mode_graph_scheduling =*/ false, /*.split_mode_f16 =*/ true, + /*.scheduler_async =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, /*.offload_policy =*/ nullptr, @@ -4346,6 +4347,7 @@ struct llama_context * llama_new_context_with_model( cparams.k_cache_hadamard = params.k_cache_hadamard; cparams.split_mode_graph_scheduling = params.split_mode_graph_scheduling; cparams.split_mode_f16 = params.split_mode_f16; + cparams.scheduler_async = params.scheduler_async; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; cparams.cuda_params = params.cuda_params; @@ -4436,6 +4438,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: k_cache_hadam = %d\n", __func__, cparams.k_cache_hadamard); LLAMA_LOG_INFO("%s: split_mode_graph_scheduling = %d\n", __func__, cparams.split_mode_graph_scheduling); LLAMA_LOG_INFO("%s: split_mode_f16= %d\n", __func__, cparams.split_mode_f16); + LLAMA_LOG_INFO("%s: sched_async = %d\n", __func__, cparams.scheduler_async); LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -4780,13 +4783,13 @@ struct llama_context * llama_new_context_with_model( ggml_backend_sched_set_only_active_experts(ctx->sched, true); } if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH && (!model->has_tensor_overrides() || cparams.split_mode_graph_scheduling)) { - ggml_backend_sched_set_split_mode_graph(ctx->sched, true); + ggml_backend_sched_set_split_mode_graph(ctx->sched, true, cparams.scheduler_async); 
ggml_backend_sched_set_max_extra_alloc(ctx->sched, params.max_extra_alloc); if (model->has_tensor_overrides() && cparams.split_mode_graph_scheduling) { LLAMA_LOG_INFO("XXXXXXXX Split Mode Graph Scheduling is FORCED despite tensor overrides due to user choice.\n"); LLAMA_LOG_INFO("XXXXXXXX It may or might NOT infer properly due to unsupported combinations between SMGS and every possible tensor overrides.\n"); } - } + } return ctx; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 89d0ec26..d0334217 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -129,7 +129,7 @@ if (NOT WIN32) endif() llama_target_and_test(test-chat-parser.cpp) -llama_target_and_test(test-chat-template.cpp) +#llama_target_and_test(test-chat-template.cpp) llama_target_and_test(test-json-partial.cpp) llama_target_and_test(test-regex-partial.cpp)
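
To make the OpenMP path added to ggml_backend_sched_compute_splits easier to follow, here is a minimal sketch of the pattern it uses: one thread per backend walks the whole split list, every thread takes the same barriers (before splits that consume inputs or start with GGML_OP_REDUCE, and after REDUCE splits), and each split is launched only by the thread that owns its backend. split_t, run_split and compute_splits_async are illustrative placeholders, not scheduler API.

    #include <omp.h>
    #include <vector>

    struct split_t {
        int  backend_id;     // backend that owns this split
        bool has_inputs;     // split consumes tensors copied from another backend
        bool starts_reduce;  // first node of the split is a GGML_OP_REDUCE
    };

    static void run_split(int backend_id, const split_t & s) {
        // stands in for ggml_backend_sched_copy_inputs() + ggml_backend_graph_compute_async()
        (void)backend_id; (void)s;
    }

    void compute_splits_async(const std::vector<split_t> & splits, int n_backends) {
        // the patch raises the active-level limit so this region can nest inside
        // another OpenMP parallel region
        if (int nlevels = omp_get_max_active_levels(); nlevels < 2) {
            omp_set_max_active_levels(nlevels + 1);
        }

        #pragma omp parallel num_threads(n_backends)
        {
            const int ith = omp_get_thread_num();   // thread i drives backend i
            for (const auto & s : splits) {
                // the barrier condition depends only on the split, so all threads
                // reach the same barriers in the same order
                if (s.has_inputs || s.starts_reduce) {
                    #pragma omp barrier             // producers must have queued their work
                }
                if (ith == s.backend_id) {
                    run_split(ith, s);              // only the owning backend computes it
                }
                if (s.starts_reduce) {
                    #pragma omp barrier             // reduced result must be visible to consumers
                }
            }
        }
    }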
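
The non-OpenMP fallback is what motivates the CMAKE_CXX_STANDARD bump from 17 to 20: std::barrier is a C++20 facility. A reduced sketch of that shape follows; the ownership rule (i % n_backends) is a placeholder, while the real scheduler uses split->backend_id and only synchronizes at splits that have inputs or start with GGML_OP_REDUCE. Per the commit message, this path is the slower of the two and is used when OpenMP is unavailable or the CPU backend has splits of its own.

    #include <barrier>
    #include <thread>
    #include <vector>

    void compute_splits_fallback(int n_backends, int n_splits) {
        std::barrier barrier(n_backends, [] () noexcept {});   // no-op completion function

        auto worker = [&] (int ith) {
            for (int i = 0; i < n_splits; ++i) {
                barrier.arrive_and_wait();                      // wait until producers have queued their work
                if (i % n_backends == ith) {
                    // ggml_backend_graph_compute_async(backends[ith], &splits[i].graph) in the real code
                }
            }
        };

        std::vector<std::thread> workers;
        workers.reserve(n_backends);
        for (int i = 0; i < n_backends; ++i) {
            workers.emplace_back(worker, i);
        }
        for (auto & w : workers) {
            w.join();
        }
    }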
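
ggml_sched_prepare_graph packs the non-host input copies of each backend's splits into one pre-allocated scratch buffer (bounded by max_extra_alloc) so nothing is allocated while the workers are running. A small sketch of the placement rule, with illustrative names; the real code walks ggml_backend_sched_split directly and additionally refuses this path when the buffer would wrap within fewer than three splits.

    #include <cstddef>
    #include <vector>

    // inputs are placed on 256-byte boundaries, matching the scheduler's tensor_size() rounding
    static size_t padded_size(size_t nbytes) {
        return 256 * ((nbytes + 255) / 256);
    }

    // Byte offset at which each split's input copies start inside the backend's scratch
    // buffer.  split_tensor_bytes[i] holds the byte sizes of the non-host inputs of split i.
    // When a split would overflow the buffer, the write position wraps back to the start
    // and the space of earlier, already-evaluated splits is reused.
    std::vector<size_t> split_input_offsets(const std::vector<std::vector<size_t>> & split_tensor_bytes,
                                            size_t scratch_size) {
        std::vector<size_t> offsets;
        size_t pos = 0;
        for (const auto & tensors : split_tensor_bytes) {
            size_t need = 0;
            for (size_t nbytes : tensors) {
                need += padded_size(nbytes);
            }
            if (pos + need > scratch_size) {
                pos = 0;                 // reuse the scratch buffer from the beginning
            }
            offsets.push_back(pos);
            pos += need;
        }
        return offsets;
    }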
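
The new option travels gpt_params.scheduler_async -> llama_context_params.scheduler_async -> ggml_backend_sched_set_split_mode_graph(sched, true, async), and the async path is only taken when split mode "graph" is active, there are more than two backends, and the split graph contains a GGML_OP_REDUCE node. For API users the equivalent of -sas / --scheduler-async is a single field; a hedged sketch, where make_async_context is not a library function and loading the model with split mode "graph" is assumed to happen elsewhere:

    #include "llama.h"

    llama_context * make_async_context(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.scheduler_async = true;   // same effect as -sas / --scheduler-async on the CLI
        return llama_new_context_with_model(model, cparams);
    }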