Offload only activated experts to the GPU (#698)

* Offload only activated experts

* This seems to do the trick for -fmoe

* Do not recalculate activated expers for fused up/gate

* Log out of bounds access details

* Add a command line argument

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-09-04 12:22:30 +02:00
committed by GitHub
parent 144d456717
commit 13c3b6412e
8 changed files with 155 additions and 45 deletions

View File

@@ -210,6 +210,7 @@ extern "C" {
// enable or disable op offload for a given op
GGML_API void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op op, bool on_or_off);
GGML_API void ggml_backend_sched_set_only_active_experts(ggml_backend_sched_t sched, bool on_or_off);
//
// Utils

View File

@@ -1493,7 +1493,7 @@ add_library(ggml
../include/ggml-backend.h
ggml.c
ggml-alloc.c
ggml-backend.c
ggml-backend.cpp
ggml-quants.c
ggml-quants.h
${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}

View File

@@ -3,12 +3,14 @@
#include "ggml-impl.h"
#include "ggml-rpc.h"
#include <assert.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <cassert>
#include <climits>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#include <set>
#define IK_PRINT_TIMING 0
@@ -60,9 +62,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context,
size_t size) {
ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
(*buffer) = (struct ggml_backend_buffer) {
ggml_backend_buffer_t buffer = new ggml_backend_buffer {
/* .interface = */ iface,
/* .buft = */ buft,
/* .context = */ context,
@@ -200,6 +200,7 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) {
void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
if (offset + size > ggml_nbytes(tensor)) fprintf(stderr, "%s(%s): offset = %zu, size = %zu, nbytes = %zu\n", __func__, tensor->name, offset, size, ggml_nbytes(tensor));
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
if (backend->iface.set_tensor_async == NULL) {
@@ -442,6 +443,29 @@ static size_t ggml_backend_registry_count = 0;
GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data);
#ifdef GGML_USE_CUDA
extern "C" GGML_CALL void ggml_backend_cuda_reg_devices(void);
#endif
#ifdef GGML_USE_SYCL
extern "C" void ggml_backend_sycl_reg_devices(void);
#endif
#ifdef GGML_USE_METAL
extern "C" GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
extern "C" GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
#endif
#ifdef GGML_USE_VULKAN
extern "C" GGML_CALL int ggml_backend_vk_reg_devices(void);
#endif
#ifdef GGML_USE_KOMPUTE
extern "C" GGML_CALL void ggml_backend_kompute_reg_devices(void);
#endif
#ifdef GGML_USE_CANN
extern "C" GGML_CALL int ggml_backend_cann_reg_devices(void);
#endif
#ifdef GGML_USE_RPC
extern "C" GGML_CALL void ggml_backend_rpc_reg_devices(void);
#endif
GGML_CALL static void ggml_backend_registry_init(void) {
static bool initialized = false;
@@ -455,37 +479,29 @@ GGML_CALL static void ggml_backend_registry_init(void) {
// add forward decls here to avoid including the backend headers
#ifdef GGML_USE_CUDA
extern GGML_CALL void ggml_backend_cuda_reg_devices(void);
ggml_backend_cuda_reg_devices();
#endif
#ifdef GGML_USE_SYCL
extern void ggml_backend_sycl_reg_devices(void);
ggml_backend_sycl_reg_devices();
#endif
#ifdef GGML_USE_METAL
extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data);
extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL);
#endif
#ifdef GGML_USE_VULKAN
extern GGML_CALL int ggml_backend_vk_reg_devices(void);
ggml_backend_vk_reg_devices();
#endif
#ifdef GGML_USE_KOMPUTE
extern GGML_CALL void ggml_backend_kompute_reg_devices(void);
ggml_backend_kompute_reg_devices();
#endif
#ifdef GGML_USE_CANN
extern GGML_CALL int ggml_backend_cann_reg_devices(void);
ggml_backend_cann_reg_devices();
#endif
#ifdef GGML_USE_RPC
extern GGML_CALL void ggml_backend_rpc_reg_devices(void);
ggml_backend_rpc_reg_devices();
#endif
}
@@ -495,11 +511,11 @@ GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn ini
size_t id = ggml_backend_registry_count;
ggml_backend_registry[id] = (struct ggml_backend_reg) {
ggml_backend_registry[id] = ggml_backend_reg {
/* .name = */ {0},
/* .fn = */ init_fn,
/* .default_buffer_type = */ default_buffer_type,
/* .user_data = */ user_data,
/* .user_data = */ user_data
};
snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name);
@@ -804,13 +820,13 @@ struct ggml_backend_plan_cpu {
GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
struct ggml_backend_plan_cpu * cpu_plan = (ggml_backend_plan_cpu *)malloc(sizeof(struct ggml_backend_plan_cpu));
cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
cpu_plan->cgraph = *cgraph; // FIXME: deep copy
if (cpu_plan->cplan.work_size > 0) {
cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
cpu_plan->cplan.work_data = (uint8_t *)malloc(cpu_plan->cplan.work_size);
if (cpu_plan->cplan.work_data == NULL) {
free(cpu_plan);
return NULL;
@@ -854,7 +870,7 @@ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t
}
cpu_ctx->work_size = cplan.work_size;
}
cplan.work_data = cpu_ctx->work_data;
cplan.work_data = (uint8_t *)cpu_ctx->work_data;
cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
@@ -915,7 +931,7 @@ static ggml_guid_t ggml_backend_cpu_guid(void) {
}
ggml_backend_t ggml_backend_cpu_init(void) {
struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
struct ggml_backend_cpu_context * ctx = (ggml_backend_cpu_context *)malloc(sizeof(struct ggml_backend_cpu_context));
if (ctx == NULL) {
return NULL;
}
@@ -926,13 +942,13 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL;
ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
ggml_backend_t cpu_backend = (ggml_backend_t)malloc(sizeof(struct ggml_backend));
if (cpu_backend == NULL) {
free(ctx);
return NULL;
}
*cpu_backend = (struct ggml_backend) {
*cpu_backend = ggml_backend {
/* .guid = */ ggml_backend_cpu_guid(),
/* .interface = */ cpu_backend_i,
/* .context = */ ctx
@@ -1144,6 +1160,7 @@ struct ggml_backend_sched {
uint32_t op_offload[(GGML_OP_COUNT + 31)/32];
bool only_active_experts;
bool debug;
};
@@ -1164,6 +1181,11 @@ void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op
}
}
void ggml_backend_sched_set_only_active_experts(ggml_backend_sched_t sched, bool on_or_off) {
if (!sched) return;
sched->only_active_experts = on_or_off;
}
static inline bool ggml_backend_sched_offload_enabled(ggml_backend_sched_t sched, enum ggml_op op) {
int int_op = (int)op;
if (!sched || op < 0 || op >= GGML_OP_COUNT) return false;
@@ -1630,7 +1652,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
i_split++;
if (i_split >= sched->splits_capacity) {
sched->splits_capacity *= 2;
sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
sched->splits = (ggml_backend_sched_split *)realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
GGML_ASSERT(sched->splits != NULL);
}
GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
@@ -1720,8 +1742,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
if (sched->graph.size < graph_size) {
sched->graph.size = graph_size;
sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));
sched->graph.nodes = (ggml_tensor **)realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
sched->graph.leafs = (ggml_tensor **)realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));
GGML_ASSERT(sched->graph.nodes != NULL);
GGML_ASSERT(sched->graph.leafs != NULL);
}
@@ -1844,6 +1866,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
int cur_arg = 0;
std::vector<int32_t> ids;
std::set<int32_t> unique_ids;
//printf("Graph split %d has %d inputs:\n", i, split->n_inputs);
//for (int j = 0; j < split->n_inputs; j++) printf(" %s, %s\n", split->inputs[j]->name,
// split->inputs[j]->src[0] ? split->inputs[j]->src[0]->name : "none");
// copy the input tensors to the split backend
for (int j = 0; j < split->n_inputs; j++) {
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
@@ -1865,6 +1895,71 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
} else {
ggml_backend_synchronize(split_backend);
}
ggml_tensor * node = split->graph.nodes[0];
if (sched->only_active_experts && split->graph.n_nodes > 0 &&
ggml_backend_buffer_get_usage(input->buffer) == GGML_BACKEND_BUFFER_USAGE_WEIGHTS &&
ggml_backend_buffer_is_host(input->buffer) &&
node->src[cur_arg] == input_cpy &&
(node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE)) {
if (ids.empty()) {
// find the ids
ggml_tensor * ids_tensor = node->op == GGML_OP_MUL_MAT_ID ? node->src[2] : node->src[3];
ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
ggml_backend_synchronize(input_backend);
ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
ggml_backend_synchronize(split_backend);
for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
unique_ids.insert(id);
}
}
// group consecutive experts and copy them together
GGML_ASSERT(!unique_ids.empty());
}
auto it = unique_ids.begin();
int32_t first_id = *it;
int32_t last_id = first_id;
auto copy_experts = [&](int32_t first_id, int32_t last_id) {
const size_t expert_size = (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE) ? input->nb[2] : input->nb[1];
const size_t expert_offset = first_id * expert_size;
const size_t expert_size_copy = (last_id - first_id + 1) * expert_size;
const size_t padding = 512;
const size_t padding_end = last_id < input->ne[2] - 1 ? std::min<size_t>(expert_size, padding) : 0;
ggml_backend_tensor_set_async(split_backend,
input_cpy,
(const uint8_t *)input->data + expert_offset, expert_offset,
// copy a bit extra to ensure there are no NaNs in the padding
expert_size_copy + padding_end);
};
for (++it; it != unique_ids.end(); ++it) {
const int32_t id = *it;
if (id == last_id + 1) {
last_id = id;
continue;
}
copy_experts(first_id, last_id);
first_id = id;
last_id = id;
}
copy_experts(first_id, last_id);
if (node->op == GGML_OP_MOE_FUSED_UP_GATE) ++cur_arg;
} else
// try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
// TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
@@ -1950,7 +2045,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU
struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched));
struct ggml_backend_sched * sched = (ggml_backend_sched *)calloc(1, sizeof(struct ggml_backend_sched));
for (int i = 0; i < (GGML_OP_COUNT + 31)/32; ++i) sched->op_offload[i] = 0xffffffff;
@@ -1961,20 +2056,20 @@ ggml_backend_sched_t ggml_backend_sched_new(
// initialize hash table
// FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)
sched->hash_set = ggml_hash_set_new(graph_size);
sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
sched->hv_tensor_backend_ids = (int *)malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
sched->hv_tensor_copies = (ggml_tensor **)malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
sched->node_backend_ids = (int *)calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
sched->leaf_backend_ids = (int *)calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
sched->prev_node_backend_ids = (int *)calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
sched->prev_leaf_backend_ids = (int *)calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
sched->context_buffer = malloc(sched->context_buffer_size);
sched->context_buffer = (char *)malloc(sched->context_buffer_size);
const int initial_splits_capacity = 16;
sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0]));
sched->splits = (ggml_backend_sched_split *)calloc(initial_splits_capacity, sizeof(sched->splits[0]));
sched->splits_capacity = initial_splits_capacity;
for (int b = 0; b < n_backends; b++) {
@@ -2219,8 +2314,8 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
bool * node_init = calloc(hash_set.size, sizeof(node_init[0]));
struct ggml_tensor ** node_copies = (ggml_tensor **)calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
bool * node_init = (bool *)calloc(hash_set.size, sizeof(node_init[0]));
struct ggml_init_params params = {
/* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
@@ -2238,7 +2333,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
free(node_init);
ggml_free(ctx_allocated);
ggml_free(ctx_unallocated);
return (struct ggml_backend_graph_copy) {
return {
/* .buffer = */ NULL,
/* .ctx_allocated = */ NULL,
/* .ctx_unallocated = */ NULL,
@@ -2261,7 +2356,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
free(node_init);
ggml_free(ctx_allocated);
ggml_free(ctx_unallocated);
return (struct ggml_backend_graph_copy) {
return {
/* .buffer = */ NULL,
/* .ctx_allocated = */ NULL,
/* .ctx_unallocated = */ NULL,
@@ -2290,7 +2385,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
free(node_copies);
free(node_init);
return (struct ggml_backend_graph_copy) {
return {
/* .buffer = */ buffer,
/* .ctx_allocated = */ ctx_allocated,
/* .ctx_unallocated = */ ctx_unallocated,

View File

@@ -4288,8 +4288,9 @@ GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const
if (batch_size < min_batch_size) return false;
int64_t n_experts_tot = op->src[0]->ne[2];
int64_t n_experts_active = ids->ne[0];
//printf("%s(%s): op->ne[2] = %ld, n_experts_tot = %ld, n_experts_active = %ld, ids: %s, %ld x %ld x %ld x %ld\n", __func__, op->name, op->ne[2], n_experts_tot, n_experts_active, ids->name, ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3]);
return batch_size*n_experts_active >= min_batch_size*n_experts_tot;
bool should_offload = batch_size*n_experts_active >= min_batch_size*n_experts_tot;
//printf("%s(%s): op->ne[2] = %ld, n_experts_tot = %ld, n_experts_active = %ld, ids: %s, %ld x %ld x %ld x %ld -> %d (%ld, %ld)\n", __func__, op->name, op->ne[2], n_experts_tot, n_experts_active, ids->name, ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], should_offload, batch_size*n_experts_active, min_batch_size*n_experts_tot);
return should_offload;
}
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;