mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 02:41:47 +00:00
Automatically disable CUDA graphs for split mode "graph" (#1040)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -3725,7 +3725,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||||||
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
||||||
// flag used to determine whether it is an integrated_gpu
|
// flag used to determine whether it is an integrated_gpu
|
||||||
// TODO
|
// TODO
|
||||||
const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
[[maybe_unused]] const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
||||||
|
|
||||||
//printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us());
|
//printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us());
|
||||||
while (!graph_evaluated_or_captured) {
|
while (!graph_evaluated_or_captured) {
|
||||||
@@ -3763,8 +3763,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||||||
assert(node->src[j]->buffer);
|
assert(node->src[j]->buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
GGML_UNUSED(integrated);
|
|
||||||
#endif // NDEBUG
|
#endif // NDEBUG
|
||||||
|
|
||||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
|
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
|
||||||
@@ -3816,15 +3814,19 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||||||
#ifdef USE_CUDA_GRAPH
|
#ifdef USE_CUDA_GRAPH
|
||||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||||
|
|
||||||
|
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
|
||||||
|
// or previous graph capture failure.
|
||||||
|
// Also disable for multi-gpu for now. TO DO investigate
|
||||||
|
bool use_cuda_graph = !disable_cuda_graphs_due_to_env && cuda_ctx->use_cuda_graph;
|
||||||
|
|
||||||
// Objects required for CUDA Graph
|
// Objects required for CUDA Graph
|
||||||
if (cuda_ctx->cuda_graph == nullptr) {
|
if (cuda_ctx->cuda_graph == nullptr) {
|
||||||
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool use_cuda_graph = true;
|
|
||||||
bool cuda_graph_update_required = false;
|
bool cuda_graph_update_required = false;
|
||||||
|
|
||||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
if (use_cuda_graph && cuda_ctx->cuda_graph->graph == nullptr) {
|
||||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
|
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
|
||||||
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
@@ -3833,13 +3835,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
|
if (use_cuda_graph && (
|
||||||
// or previous graph capture failure.
|
cuda_ctx->cuda_graph->disable_due_to_gpu_arch ||
|
||||||
// Also disable for multi-gpu for now. TO DO investigate
|
cuda_ctx->cuda_graph->disable_due_to_too_many_updates ||
|
||||||
if (disable_cuda_graphs_due_to_env
|
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture)) {
|
||||||
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|
|
||||||
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|
|
||||||
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
|
|
||||||
use_cuda_graph = false;
|
use_cuda_graph = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4287,6 +4286,11 @@ struct cuda_params {
|
|||||||
int fusion = GGML_CUDA_FUSION;
|
int fusion = GGML_CUDA_FUSION;
|
||||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||||
int mmq_id_thresh = 32;
|
int mmq_id_thresh = 32;
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
bool use_cuda_graph = true;
|
||||||
|
#else
|
||||||
|
bool use_cuda_graph = false;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
|
static std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
|
||||||
@@ -4333,6 +4337,11 @@ static cuda_params ggml_cuda_parse_params(const char * params_string) {
|
|||||||
else if (parsed[0] == "mmq-id-size") {
|
else if (parsed[0] == "mmq-id-size") {
|
||||||
is_good = read_value(parsed[1], params.mmq_id_thresh);
|
is_good = read_value(parsed[1], params.mmq_id_thresh);
|
||||||
}
|
}
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
else if (parsed[0] == "graphs") {
|
||||||
|
is_good = read_value(parsed[1], params.use_cuda_graph);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
if (!is_good) {
|
if (!is_good) {
|
||||||
GGML_CUDA_LOG_WARN("%s: invalid parameter %s (%d) -> ignored\n", __func__, value.c_str(), (int)parsed.size());
|
GGML_CUDA_LOG_WARN("%s: invalid parameter %s (%d) -> ignored\n", __func__, value.c_str(), (int)parsed.size());
|
||||||
@@ -4373,6 +4382,12 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] con
|
|||||||
GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
|
GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
|
||||||
ctx->mmq_id_thresh = params.mmq_id_thresh;
|
ctx->mmq_id_thresh = params.mmq_id_thresh;
|
||||||
}
|
}
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
if (params.use_cuda_graph != ctx->use_cuda_graph) {
|
||||||
|
GGML_CUDA_LOG_INFO(" =========================== %s: setting use_cuda_graph to %d\n", __func__, params.use_cuda_graph);
|
||||||
|
ctx->use_cuda_graph = params.use_cuda_graph;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
return cuda_backend;
|
return cuda_backend;
|
||||||
|
|||||||
@@ -840,6 +840,9 @@ struct ggml_backend_cuda_context {
|
|||||||
int fusion = GGML_CUDA_FUSION;
|
int fusion = GGML_CUDA_FUSION;
|
||||||
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
|
||||||
int mmq_id_thresh = 32;
|
int mmq_id_thresh = 32;
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
bool use_cuda_graph = true;
|
||||||
|
#endif
|
||||||
|
|
||||||
explicit ggml_backend_cuda_context(int device);
|
explicit ggml_backend_cuda_context(int device);
|
||||||
|
|
||||||
|
|||||||
@@ -4480,8 +4480,16 @@ struct llama_context * llama_new_context_with_model(
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
// LLAMA_SPLIT_MODE_LAYER and LLAMA_SPLIT_MODE_GRAPH require a backend for each GPU
|
// LLAMA_SPLIT_MODE_LAYER and LLAMA_SPLIT_MODE_GRAPH require a backend for each GPU
|
||||||
|
auto params = cparams.cuda_params;
|
||||||
|
std::string new_params;
|
||||||
|
if (model->split_mode == LLAMA_SPLIT_MODE_GRAPH) {
|
||||||
|
static const std::string extra_string{"graphs=0"};
|
||||||
|
if (params) new_params = std::string{(const char *)params} + ',';
|
||||||
|
new_params += extra_string;
|
||||||
|
params = new_params.data();
|
||||||
|
}
|
||||||
for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
|
for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
|
||||||
ggml_backend_t backend = ggml_backend_cuda_init(device, cparams.cuda_params);
|
ggml_backend_t backend = ggml_backend_cuda_init(device, params);
|
||||||
if (backend == nullptr) {
|
if (backend == nullptr) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
|
LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
|
|||||||
Reference in New Issue
Block a user