diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index d3868f2e..605b68ce 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -50,6 +50,7 @@ #include #include #include +#include #include #include #include @@ -78,6 +79,7 @@ GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, #define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) #define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) #define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define GGML_CUDA_LOG_DEBUG(...) ggml_cuda_log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) GGML_ATTRIBUTE_FORMAT(2, 3) static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) { @@ -445,6 +447,70 @@ std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(i return std::unique_ptr(new ggml_cuda_pool_leg(device)); } +struct ggml_graph_node_properties { + void * node_address; + ggml_op node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; +}; + +struct ggml_cuda_graph { +#ifdef USE_CUDA_GRAPH + ~ggml_cuda_graph() { + if (instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(instance)); + } + if (graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph)); + } + } + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t num_nodes = 0; + std::vector nodes; + std::vector params; + bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; + int number_consecutive_updates = 0; + std::vector ggml_graph_properties; + bool use_cpy_indirection = false; + std::vector cpy_dest_ptrs; + char ** dest_ptrs_d; + int dest_ptrs_size = 0; + // Index to allow each cpy kernel to be aware of it's position within the graph + // relative to other cpy nodes. + int graph_cpynode_index = -1; +#endif +}; + +static std::mutex ggml_cuda_lock; +static std::condition_variable ggml_cuda_lock_cv; +static std::atomic ggml_cuda_lock_counter; + +ggml_backend_cuda_context::ggml_backend_cuda_context(int device) : + device(device), name(GGML_CUDA_NAME + std::to_string(device)) { +} + +ggml_backend_cuda_context::~ggml_backend_cuda_context() { + if (copy_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(copy_event)); + } + for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) { + for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) { + if (streams[i][j] != nullptr) { + CUDA_CHECK(cudaStreamDestroy(streams[i][j])); + } + } + if (cublas_handles[i] != nullptr) { + CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); + } + } + +} + // cuda buffer struct ggml_backend_cuda_buffer_context { @@ -2393,7 +2459,7 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst_row.nb[2] = nb1; dst_row.nb[3] = nb1; - if (ne12 == 1) { + if (false && ne12 == 1) { std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); @@ -2913,6 +2979,29 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor return fuse_down; } +static void ggml_cuda_cpy_wrapper(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { + auto src0 = dst->src[0]; + auto src1 = dst->src[1]; + if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + CUDA_CHECK(cudaMemcpyAsync((char *)src1->data, (char *)src0->data, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, ctx.stream())); + return; + } +#ifdef USE_CUDA_GRAPH + if (ctx.cuda_graph->use_cpy_indirection) { + GGML_ASSERT(ctx.cuda_graph->graph_cpynode_index < (int)ctx.cuda_graph->cpy_dest_ptrs.size()); + auto dest_ptr = ctx.cuda_graph->cpy_dest_ptrs[ctx.cuda_graph->graph_cpynode_index]; + ggml_tensor aux_src1 = *src1; + aux_src1.data = dest_ptr; + ggml_cuda_cpy(ctx, src0, &aux_src1); + ++ctx.cuda_graph->graph_cpynode_index; + } else { + ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]); + } +#else + ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]); +#endif +} + static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst, struct ggml_tensor * next, bool& skip_next) { // why is this here instead of mul_mat? if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { @@ -2934,7 +3023,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_dup(ctx, dst); break; case GGML_OP_CPY: - ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]); + ggml_cuda_cpy_wrapper(ctx, dst); + //ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]); break; case GGML_OP_CONT: ggml_cuda_dup(ctx, dst); @@ -3216,6 +3306,102 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +#ifdef USE_CUDA_GRAPH +static void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, + const int host_dest_ptrs_size, cudaStream_t stream) { + if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers + CUDA_CHECK(cudaStreamSynchronize(stream)); + if (cuda_graph->dest_ptrs_d != nullptr) { + CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d)); + } + CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *))); + cuda_graph->dest_ptrs_size = host_dest_ptrs_size; + } + // copy destination pointers to GPU + CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream)); + cuda_graph->graph_cpynode_index = 0; // reset index +} + +static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + bool use_cuda_graph) { + + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); + + const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; + const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; + const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; + const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; + const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); +#endif + } + + if (node->op == GGML_OP_ADD && + node->src[1] && node->src[1]->ne[1] > 1 && + (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && + (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && + strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) { + // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation + // by means of matching node names. See + // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and + // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, + // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } + + if (node->op == GGML_OP_CPY) { + + // Store the pointers which are updated for each token, such that these can be sent + // to the device and accessed using indirection from CUDA graph + cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data); + + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (!ptr) { + use_cuda_graph = false; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); +#endif + } + } + if (!use_cuda_graph) { + break; + } + } + + if (use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = true; + // copy pointers to GPU so they can be accessed via indirection within CUDA graph + ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream()); + } + + return use_cuda_graph; +} + static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; @@ -3226,6 +3412,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p for (int i = 0; i < GGML_MAX_SRC; i++) { graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } + memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); } static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { @@ -3257,9 +3444,246 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + + if (node->op == GGML_OP_SCALE && + memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + return true; } +static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { + + bool cuda_graph_update_required = false; + + if (cuda_ctx->cuda_graph->instance == nullptr) { + cuda_graph_update_required = true; + } + + // Check if the graph size has changed + if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cuda_graph_update_required = true; + cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } + + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = true; + if (!cuda_graph_update_required) { + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + if (!has_matching_properties) { + cuda_graph_update_required = true; + } + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + + return cuda_graph_update_required; +} + +static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { + +#if CUDART_VERSION >= 12000 + cudaGraphExecUpdateResultInfo result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#else + cudaGraphNode_t errorNode; + cudaGraphExecUpdateResult result_info; + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#endif // CUDART_VERSION >= 12000 + + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); +#endif + + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + (void)cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } +} +#endif + +static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { + // flag used to determine whether it is an integrated_gpu + // TODO + const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated; + + while (!graph_evaluated_or_captured) { + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. + if (!use_cuda_graph || cuda_graph_update_required) { + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + +#if 0 + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion) { + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + i += 2; + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); + continue; + } + } +#endif +#ifndef NDEBUG + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); + //assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || + // ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); + } + } +#else + GGML_UNUSED(integrated); +#endif // NDEBUG + + bool skip_next = false; + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next); + if (!ok) { + GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + if (skip_next) ++i; + } + } +#ifdef USE_CUDA_GRAPH + if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture + if (cuda_ctx->cuda_graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); + cuda_ctx->cuda_graph->graph = nullptr; + } + + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + graph_evaluated_or_captured = true; // CUDA graph has been captured + + std::lock_guard lock(ggml_cuda_lock); + if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { + ggml_cuda_lock_cv.notify_all(); + } + } else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated + } + } + + if (use_cuda_graph) { + if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } + if (cuda_graph_update_required) { // Update graph executable + update_cuda_graph_executable(cuda_ctx); + } + // Launch graph + CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); +#else + graph_evaluated_or_captured = true; +#endif // USE_CUDA_GRAPH + } +} + +GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + +#ifdef USE_CUDA_GRAPH + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + + // Objects required for CUDA Graph + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + } + + bool use_cuda_graph = true; + bool cuda_graph_update_required = false; + + if (cuda_ctx->cuda_graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { + cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); +#endif + } + } + + // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, + // or previous graph capture failure. + // Also disable for multi-gpu for now. TO DO investigate + if (disable_cuda_graphs_due_to_env + || cuda_ctx->cuda_graph->disable_due_to_gpu_arch + || cuda_ctx->cuda_graph->disable_due_to_too_many_updates + || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { + use_cuda_graph = false; + } + + if (use_cuda_graph) { + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph); + + // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. + if (use_cuda_graph && cuda_graph_update_required) { + cuda_ctx->cuda_graph->number_consecutive_updates++; + } else { + cuda_ctx->cuda_graph->number_consecutive_updates = 0; + } + + if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; +#ifndef NDEBUG + GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); +#endif + } + } + + if (use_cuda_graph && cuda_graph_update_required) { + // Start CUDA graph capture + { + std::lock_guard lock(ggml_cuda_lock); + ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + } + + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + } + + if (!use_cuda_graph) { + cuda_ctx->cuda_graph->use_cpy_indirection = false; + } + +#else + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; +#endif // USE_CUDA_GRAPH + + bool graph_evaluated_or_captured = false; + + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + + return GGML_STATUS_SUCCESS; +} + +/* GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -3528,6 +3952,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t return GGML_STATUS_SUCCESS; } +*/ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 55088c33..b24f7fba 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -825,37 +825,7 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { - void * node_address; - ggml_op node_op; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_address[GGML_MAX_SRC]; -}; - -struct ggml_cuda_graph { -#ifdef USE_CUDA_GRAPH - ~ggml_cuda_graph() { - if (instance != nullptr) { - CUDA_CHECK(cudaGraphExecDestroy(instance)); - } - if (graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(graph)); - } - } - cudaGraph_t graph = nullptr; - cudaGraphExec_t instance = nullptr; - size_t num_nodes = 0; - std::vector nodes; - std::vector params; - bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; - int number_consecutive_updates = 0; - std::vector ggml_graph_properties; - std::vector updated_kernel_arg; -#endif -}; +struct ggml_cuda_graph; struct ggml_backend_cuda_context { int device; @@ -867,26 +837,9 @@ struct ggml_backend_cuda_context { std::unique_ptr cuda_graph; - explicit ggml_backend_cuda_context(int device) : - device(device), - name(GGML_CUDA_NAME + std::to_string(device)) { - } + explicit ggml_backend_cuda_context(int device); - ~ggml_backend_cuda_context() { - if (copy_event != nullptr) { - CUDA_CHECK(cudaEventDestroy(copy_event)); - } - for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) { - for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) { - if (streams[i][j] != nullptr) { - CUDA_CHECK(cudaStreamDestroy(streams[i][j])); - } - } - if (cublas_handles[i] != nullptr) { - CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); - } - } - } + ~ggml_backend_cuda_context(); cudaStream_t stream(int device, int stream) { if (streams[device][stream] == nullptr) {