Disable pipeline parallel for tensor override or allocation failed (#879)

* disable pipeline parallelism when tensor override present

* disable pipeline parallel if allocation failed

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2025-10-31 12:20:48 +00:00
committed by GitHub
parent 14760aaf46
commit c7dbe3f2c1
4 changed files with 28 additions and 7 deletions

View File

@@ -3969,7 +3969,7 @@ struct llama_model * llama_load_model_from_file(
return true;
};
}
model->set_tensor_overrides(params);
// model->devices hold device indices that are used to offload
// use model->devices to determine offload device
// if no device is specified, all device are included
@@ -4479,7 +4479,7 @@ struct llama_context * llama_new_context_with_model(
llama_get_device_count(*model) > 1 &&
model->n_gpu_layers > (int)model->hparams.n_layer &&
model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
params.offload_kqv;
params.offload_kqv && !model->has_tensor_overrides();
#ifndef GGML_USE_CUDA
// pipeline parallelism requires support for async compute and events
// currently this is only implemented in the CUDA backend
@@ -4498,10 +4498,19 @@ struct llama_context * llama_new_context_with_model(
ggml_cgraph * gf = llm_build_context::llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
// initialize scheduler with the worst-case graph
if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
llama_free(ctx);
return nullptr;
bool gf_success = ggml_backend_sched_reserve(ctx->sched, gf);
if (!gf_success)
{
if (pipeline_parallel) {
LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, false);
gf_success = ggml_backend_sched_reserve(ctx->sched, gf);
}
if (!gf_success) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
llama_free(ctx);
return nullptr;
}
}
for (size_t i = 0; i < ctx->backends.size(); i++) {