Use GUDA graphs also when theretensor overrides

This commit is contained in:
Kawrakow
2026-01-20 15:26:47 +00:00
parent 6f1a69352f
commit c307525cb2
4 changed files with 83 additions and 97 deletions

View File

@@ -3663,11 +3663,23 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool use_cuda_graph) {
static inline const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
return cgraph->nodes[0]->data;
}
static inline ggml_cuda_graph * ggml_cuda_get_graph(ggml_backend_cuda_context & ctx, const void * key) {
auto & graph = ctx.cuda_graphs[key];
if (!graph) {
graph = std::make_unique<ggml_cuda_graph>();
}
return graph.get();
}
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_cuda_graph * graph, ggml_cgraph * cgraph,
bool use_cuda_graph, cudaStream_t stream) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@@ -3678,16 +3690,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer %s\n", __func__, node->src[0]->name);
#endif
}
if (ggml_is_noop(node)) continue;
if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
@@ -3735,7 +3738,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Store the pointers which are updated for each token, such that these can be sent
// to the device and accessed using indirection from CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
@@ -3752,9 +3755,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
}
if (use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = true;
graph->use_cpy_indirection = true;
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
ggml_cuda_cpy_dest_ptrs_copy(graph, graph->cpy_dest_ptrs.data(), graph->cpy_dest_ptrs.size(), stream);
}
return use_cuda_graph;
@@ -3811,18 +3814,18 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return true;
}
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
static bool is_cuda_graph_update_required(ggml_cuda_graph * graph, ggml_cgraph * cgraph) {
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->instance == nullptr) {
if (graph->instance == nullptr) {
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
if (graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
@@ -3830,26 +3833,26 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph->ggml_graph_properties[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
set_ggml_graph_node_properties(cgraph->nodes[i], &graph->ggml_graph_properties[i]);
}
return cuda_graph_update_required;
}
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
static void update_cuda_graph_executable(ggml_cuda_graph * graph) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
#else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000
if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -3860,9 +3863,9 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
(void)cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
cuda_ctx->cuda_graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
} else {
GGML_ASSERT(stat == cudaSuccess);
}
@@ -3875,6 +3878,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// TODO
[[maybe_unused]] const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
#ifdef USE_CUDA_GRAPH
auto graph = ggml_cuda_get_graph(*cuda_ctx, ggml_cuda_graph_get_key(cgraph));
#endif
//printf("======================== %s: graph with %d nodes on device %d. time = %ld\n", __func__, cgraph->n_nodes, cuda_ctx->device, ggml_time_us());
while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -3884,34 +3891,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
#if 0
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
i += 2;
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
continue;
}
}
#endif
#ifndef NDEBUG
//assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
//for (int j = 0; j < GGML_MAX_SRC; j++) {
// if (node->src[j] != nullptr) {
// assert(node->src[j]->buffer);
// }
//}
#endif // NDEBUG
if (ggml_is_noop(node)) continue;
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, cgraph, i);
if (!ok) {
@@ -3922,12 +3902,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
#ifdef USE_CUDA_GRAPH
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
if (cuda_ctx->cuda_graph->graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
cuda_ctx->cuda_graph->graph = nullptr;
if (graph->graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(graph->graph));
graph->graph = nullptr;
}
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
graph_evaluated_or_captured = true; // CUDA graph has been captured
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
@@ -3940,14 +3920,14 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
}
if (use_cuda_graph) {
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
if (graph->instance == nullptr) { // Create executable graph from captured graph.
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
}
if (cuda_graph_update_required) { // Update graph executable
update_cuda_graph_executable(cuda_ctx);
update_cuda_graph_executable(graph);
}
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
#else
graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH
@@ -3960,6 +3940,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
ggml_cuda_set_device(cuda_ctx->device);
#ifdef USE_CUDA_GRAPH
cuda_ctx->cur_graph = nullptr;
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
@@ -3967,16 +3949,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// Also disable for multi-gpu for now. TO DO investigate
bool use_cuda_graph = !disable_cuda_graphs_due_to_env && cuda_ctx->use_cuda_graph;
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
auto graph = ggml_cuda_get_graph(*cuda_ctx, ggml_cuda_graph_get_key(cgraph));
cuda_ctx->cur_graph = graph;
bool cuda_graph_update_required = false;
if (use_cuda_graph && cuda_ctx->cuda_graph->graph == nullptr) {
if (use_cuda_graph && graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
@@ -3984,26 +3964,26 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}
if (use_cuda_graph && (
cuda_ctx->cuda_graph->disable_due_to_gpu_arch ||
cuda_ctx->cuda_graph->disable_due_to_too_many_updates ||
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture)) {
graph->disable_due_to_gpu_arch ||
graph->disable_due_to_too_many_updates ||
graph->disable_due_to_failed_graph_capture)) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
cuda_graph_update_required = is_cuda_graph_update_required(graph, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(graph, cgraph, use_cuda_graph, cuda_ctx->stream());
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
if (graph->number_consecutive_updates >= 4) {
graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
@@ -4021,7 +4001,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}
if (!use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = false;
graph->use_cpy_indirection = false;
}
#else

View File

@@ -25,6 +25,7 @@
#include <cfloat>
#include <string>
#include <vector>
#include <unordered_map>
#if defined(GGML_USE_HIPBLAS)
#include "vendors/hip.h"
@@ -849,13 +850,16 @@ struct ggml_backend_cuda_context {
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
std::unique_ptr<ggml_cuda_graph> cuda_graph;
int fusion = GGML_CUDA_FUSION;
int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
int mmq_id_thresh = 32;
#ifdef USE_CUDA_GRAPH
bool use_cuda_graph = true;
ggml_cuda_graph * cur_graph = nullptr;
std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
#endif
void * copy_buffer = nullptr;

View File

@@ -542,9 +542,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char ** dest_ptrs_d = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(!disable_indirection_for_this_node && ctx.cuda_graph && ctx.cuda_graph->use_cpy_indirection) {
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
if(!disable_indirection_for_this_node && ctx.cur_graph && ctx.cur_graph->use_cpy_indirection) {
dest_ptrs_d = ctx.cur_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cur_graph->graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
@@ -651,8 +651,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_type_name(src0->type), ggml_type_name(src1->type));
}
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(!disable_indirection_for_this_node && ctx.cuda_graph && ctx.cuda_graph->use_cpy_indirection) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
if(!disable_indirection_for_this_node && ctx.cur_graph && ctx.cur_graph->use_cpy_indirection) {
ctx.cur_graph->graph_cpynode_index = graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection_for_this_node);
@@ -796,9 +796,9 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
char ** dest_ptrs = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
dest_ptrs = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) {
dest_ptrs = ctx.cur_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cur_graph->graph_cpynode_index;
}
#else
GGML_UNUSED(disable_indirection);
@@ -813,8 +813,8 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
}
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) {
ctx.cur_graph->graph_cpynode_index = graph_cpynode_index;
}
#endif
return true;
@@ -859,9 +859,9 @@ bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * c
char ** dest_ptrs = nullptr;
int graph_cpynode_index = -1;
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
dest_ptrs = ctx.cuda_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) {
dest_ptrs = ctx.cur_graph->dest_ptrs_d;
graph_cpynode_index = ctx.cur_graph->graph_cpynode_index;
}
#endif
@@ -874,8 +874,8 @@ bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * c
}
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
if(ctx.cur_graph->use_cpy_indirection && !disable_indirection) {
ctx.cur_graph->graph_cpynode_index = graph_cpynode_index;
}
#endif
return true;

View File

@@ -1,5 +1,7 @@
#pragma once
#include "ggml.h"
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;