WIP: graph appears to work, layer is broken

This commit is contained in:
Kawrakow
2025-12-21 14:28:14 +00:00
parent 5c9e48041e
commit 993cb00a34
5 changed files with 138 additions and 48 deletions

View File

@@ -1421,6 +1421,28 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
int * node_backend_id = &tensor_backend_id(node);
if (node->op == GGML_OP_REDUCE) {
auto view_src = node->view_src;
int src_id = -1;
for (int j = 0; j < node->op_params[1]; ++j) {
if (node->src[j]) {
int * this_node_backend_id = &tensor_backend_id(node->src[j]);
if (*this_node_backend_id == -1) {
*this_node_backend_id = j;
} else {
GGML_ASSERT(*this_node_backend_id == j);
}
if (view_src == node->src[j]) {
src_id = j;
}
}
}
if (src_id >= 0) {
int * this_node_backend_id = &tensor_backend_id(view_src);
*this_node_backend_id = tensor_backend_id(node->src[src_id]);
*node_backend_id = *this_node_backend_id;
}
}
// do not overwrite user assignments
if (*node_backend_id == -1) {
*node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
@@ -1652,6 +1674,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
// check if we should start a new split based on the sources of the current node
bool need_new_split = false;
if ((node->op == GGML_OP_ADD && node->op_params[0] == 0xff) ||
node->op == GGML_OP_REDUCE ||
node->op == GGML_OP_FAKE_CPY ||
node->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] == 0xff) {
need_new_split = true;
}
@@ -1739,6 +1763,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
// create a copy of the input in the split's backend
if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {
if (node->op == GGML_OP_REDUCE) {
//printf("setting tensor_id_copy(reduce, %zu, %d, %s) to %s\n", src_id, cur_backend_id, node->name, src->name);
tensor_id_copy(src_id, cur_backend_id, 0) = src;
} else if (node->op == GGML_OP_FAKE_CPY && src->op == GGML_OP_REDUCE) {
//printf("setting tensor_id_copy(fake_cpy, %zu, %d, %s) to %s\n", src_id, cur_backend_id, node->name, src->src[j]->name);
tensor_id_copy(src_id, cur_backend_id, 0) = src->src[j];
} else {
ggml_backend_t backend = sched->backends[cur_backend_id];
for (int c = 0; c < sched->n_copies; c++) {
struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
@@ -1753,6 +1784,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
int n_inputs = split->n_inputs++;
GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
split->inputs[n_inputs] = src;
}
}
node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);
}

View File

@@ -48,6 +48,7 @@
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/multiadd.cuh"
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/reduce.cuh"
#include <algorithm>
#include <array>
@@ -2956,6 +2957,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
//printf("%4d %s(%s) on device %d. time = %ld\n", i, ggml_op_name(dst->op), dst->name, ctx.device, ggml_time_us());
switch (dst->op) {
case GGML_OP_REDUCE:
ggml_cuda_op_reduce(ctx, dst);
break;
case GGML_OP_FAKE_CPY:
break;
case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst);
break;
@@ -4066,6 +4072,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
}
return false;
} break;
case GGML_OP_REDUCE:
case GGML_OP_FAKE_CPY:
case GGML_OP_ARGMAX:
return true;
case GGML_OP_HADAMARD:

View File

@@ -17,7 +17,7 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(nhave >=2 && nhave <= nreduce);
//printf("============================== %s on device %d\n", __func__, ctx.device);
//printf("============================== %s on device %d with %d sources\n", __func__, ctx.device, nreduce);
#ifdef GGML_USE_NCCL
auto & info = ggml_cuda_info();

View File

@@ -4290,10 +4290,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"GLU",
"REDUCE",
"FAKE_CPY",
"GLU",
};
static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
@@ -6080,6 +6080,7 @@ struct ggml_tensor * ggml_reduce(
if (a[j]) ++nhave;
}
GGML_ASSERT(nhave > 1);
result->op = GGML_OP_REDUCE;
result->op_params[0] = (int)op;
result->op_params[1] = n;
result->op_params[2] = nhave;
@@ -6091,9 +6092,9 @@ struct ggml_tensor * ggml_fake_cpy(
struct ggml_tensor * dst,
struct ggml_tensor * src) {
struct ggml_tensor * result = ggml_view_tensor(ctx, dst);
result->op = GGML_OP_FAKE_CPY;
result->src[0] = dst;
result->src[1] = src;
result->op = GGML_OP_FAKE_CPY;
return result;
}
@@ -8471,6 +8472,21 @@ struct ggml_tensor * ggml_get_rows(
if (a->type == GGML_TYPE_I32) {
type = a->type;
}
if (a->op == GGML_OP_REDUCE) {
//printf("======================= %s(%s)\n", __func__, a->name);
struct ggml_tensor * result = NULL;
for (int j = a->op_params[1]-1; j >= 0; --j) {
if (a->src[j]) {
struct ggml_tensor * aj = ggml_get_rows(ctx, a->src[j], b);
if (result == NULL) result = ggml_view_tensor(ctx, aj);
result->src[j] = aj;
}
}
GGML_ASSERT(result);
return result;
}
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
result->op = GGML_OP_GET_ROWS;

View File

@@ -642,6 +642,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s &&
up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR &&
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
//printf("%s: %s\n", __func__, ggml_op_name(input->op));
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
auto u = (ggml_split_tensor_t *)up->extra;
@@ -657,7 +658,18 @@ ggml_tensor * llm_build_context::llm_build_ffn(
auto split_d = d->splits[id];
GGML_ASSERT((!split_u && !split_g && !split_d) || (split_u && split_g && split_d));
if (!split_u) continue;
//auto cur = input->op == GGML_OP_REDUCE ? ggml_fake_cpy(ctx, input->src[id], input) : input;
auto cur = input;
if (input->op == GGML_OP_REDUCE) {
auto view_src = input->view_src;
GGML_ASSERT(view_src);
cur = input->src[id];
if (cur == view_src) {
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
cur = input;
}
}
//auto cur = input->op == GGML_OP_REDUCE ? input->src[id] : input;
if (ffn_norm && ffn_norm->extra) {
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(norm->splits[id]);
@@ -688,21 +700,25 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cb(ffn.back(), "ffn_with_inp", il);
}
if (ffn.size() == 1) return ffn.front();
auto cur = ggml_add(ctx, ffn[0], ffn[1]);
cb(cur, "combine_ffn", il);
cur->op_params[0] = 0xff;
for (int id = 2; id < int(ffn.size()); ++id) {
cur = ggml_add(ctx, cur, ffn[id]);
cb(cur, "combine_ffn", il);
}
if (ffn.size() > 2) {
cur->op_params[0] = 0xff;
}
//if (cur->type != GGML_TYPE_F32) {
// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
//}
auto cur = ggml_reduce(ctx, ffn.data(), u->n_device, GGML_OP_ADD);
cb(cur, "ffn_combined", il);
ggml_build_forward_expand(graph, cur);
return cur;
//auto cur = ggml_add(ctx, ffn[0], ffn[1]);
//cb(cur, "combine_ffn", il);
//cur->op_params[0] = 0xff;
//for (int id = 2; id < int(ffn.size()); ++id) {
// cur = ggml_add(ctx, cur, ffn[id]);
// cb(cur, "combine_ffn", il);
//}
//if (ffn.size() > 2) {
// cur->op_params[0] = 0xff;
//}
////if (cur->type != GGML_TYPE_F32) {
//// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
////}
//return cur;
}
if (ffn_norm) {
@@ -1822,6 +1838,7 @@ ggml_cgraph * llm_build_context::build_llama() {
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
this_n_swa);
}
//printf("%s: attn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op));
if (il == n_layer - 1) {
// skip computing output for unused tokens
@@ -1829,7 +1846,7 @@ ggml_cgraph * llm_build_context::build_llama() {
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "last_attn", il);
if (use_rope) {
if (!use_rope) {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
cb(inpSA, "last_ffn_inp", il);
}
@@ -1906,6 +1923,7 @@ ggml_cgraph * llm_build_context::build_llama() {
cb, il, gf, true);
cb(cur, "ffn_moe_out", il);
}
//printf("%s: ffn result for layer %d is %s, %s\n", __func__, il, cur->name, ggml_op_name(cur->op));
// For Granite architecture
if (hparams.f_residual_scale) {
@@ -9344,6 +9362,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
//printf("%s: %s\n", __func__, ggml_op_name(input->op));
ggml_split_tensor_t * attn_norm = the_attn_norm ? (ggml_split_tensor_t *)the_attn_norm->extra : nullptr;
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
@@ -9382,7 +9401,18 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) continue;
//auto cur = input = input->op == GGML_OP_REDUCE ? ggml_fake_cpy(ctx0, input->src[id], input) : input;
auto cur = input;
if (input->op == GGML_OP_REDUCE) {
auto view_src = input->view_src;
GGML_ASSERT(view_src);
cur = input->src[id];
if (cur == view_src) {
//printf("%s: Setting input to %s for id = %d\n", __func__, view_src->name, id);
cur = input;
}
}
//auto cur = input = input->op == GGML_OP_REDUCE ? input->src[id] : input;
if (attn_norm) {
auto split_norm = attn_norm->splits[id];
cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, LLM_NORM_RMS, cb, il);
@@ -9522,36 +9552,40 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cb(attn.back(), "attn_out_with_input", il);
}
if (attn.size() == 1) return attn.front();
//if (attn.size() > 2 && attn.size()%2 == 0) {
// for (int id = 0; id < int(attn.size()/2); ++id) {
// attn[id] = ggml_add(ctx0, attn[2*id+0], attn[2*id+1]);
// attn[id]->op_params[0] = 0xff;
// }
// attn.resize(attn.size()/2);
// auto cur = ggml_add(ctx0, attn[0], attn[1]);
// cur->op_params[0] = 0xff;
// cur->op_params[0] = 0xff;
// for (int id = 2; id < (int)attn.size(); ++id) {
// cur = ggml_add(ctx0, cur, attn[id]);
// cb(cur, "combine_attn", il);
// }
// return cur;
//}
auto cur = ggml_add(ctx0, attn[0], attn[1]);
cb(cur, "combine_attn", il);
cur->op_params[0] = 0xff;
for (int id = 2; id < (int)attn.size(); ++id) {
cur = ggml_add(ctx0, cur, attn[id]);
cb(cur, "combine_attn", il);
}
if (attn.size() > 2) {
cur->op_params[0] = 0xff;
}
//if (add_input) {
// cur = ggml_add(ctx0, cur, input);
// cb(cur, "combine_attn_inp", il);
//}
auto cur = ggml_reduce(ctx0, attn.data(), wq->n_device, GGML_OP_ADD);
ggml_build_forward_expand(gf, cur);
cb(cur, "attn_combined", il);
return cur;
////if (attn.size() > 2 && attn.size()%2 == 0) {
//// for (int id = 0; id < int(attn.size()/2); ++id) {
//// attn[id] = ggml_add(ctx0, attn[2*id+0], attn[2*id+1]);
//// attn[id]->op_params[0] = 0xff;
//// }
//// attn.resize(attn.size()/2);
//// auto cur = ggml_add(ctx0, attn[0], attn[1]);
//// cur->op_params[0] = 0xff;
//// cur->op_params[0] = 0xff;
//// for (int id = 2; id < (int)attn.size(); ++id) {
//// cur = ggml_add(ctx0, cur, attn[id]);
//// cb(cur, "combine_attn", il);
//// }
//// return cur;
////}
//auto cur = ggml_add(ctx0, attn[0], attn[1]);
//cb(cur, "combine_attn", il);
//cur->op_params[0] = 0xff;
//for (int id = 2; id < (int)attn.size(); ++id) {
// cur = ggml_add(ctx0, cur, attn[id]);
// cb(cur, "combine_attn", il);
//}
//if (attn.size() > 2) {
// cur->op_params[0] = 0xff;
//}
////if (add_input) {
//// cur = ggml_add(ctx0, cur, input);
//// cb(cur, "combine_attn_inp", il);
////}
//return cur;
}
}