WIP: fix sm layer (dense)

Iwan Kawrakow
2025-12-21 16:05:34 +00:00
parent 1fe53d2002
commit 5db8262d94


@@ -707,29 +707,15 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cb(cur, "ffn_combined", il);
ggml_build_forward_expand(graph, cur);
return cur;
-//auto cur = ggml_add(ctx, ffn[0], ffn[1]);
-//cb(cur, "combine_ffn", il);
-//cur->op_params[0] = 0xff;
-//for (int id = 2; id < int(ffn.size()); ++id) {
-// cur = ggml_add(ctx, cur, ffn[id]);
-// cb(cur, "combine_ffn", il);
-//}
-//if (ffn.size() > 2) {
-// cur->op_params[0] = 0xff;
-//}
-////if (cur->type != GGML_TYPE_F32) {
-//// cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
-////}
-//return cur;
}
+auto cur = input;
if (ffn_norm) {
-input = llm_build_norm(ctx, input, lctx.model.hparams, ffn_norm, NULL, LLM_NORM_RMS, cb, il);
+cur = llm_build_norm(ctx, cur, lctx.model.hparams, ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(input, "ffn_norm", il);
}
-else if (input->type != GGML_TYPE_F32) {
-input = ggml_cast(ctx, input, GGML_TYPE_F32);
+if (cur->type != GGML_TYPE_F32) {
+cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
if (lctx.cparams.fused_up_gate &&
@@ -737,7 +723,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
-auto cur = ggml_fused_up_gate(ctx, up, gate, input, unary_op);
+cur = ggml_fused_up_gate(ctx, up, gate, cur, unary_op);
cb(cur, "ffn_up_gate", il);
if (down) {
cur = llm_build_lora_mm(lctx, ctx, down, cur);
@@ -756,10 +742,14 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cur = ggml_mul(ctx, cur, down_s);
cb(cur, "ffn_down_s", il);
}
+if (add_input) {
+cur = ggml_add(ctx, cur, input);
+cb(cur, "ffn_out_with_inp", il);
+}
return cur;
}
-struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, input) : input;
+struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
cb(tmp, "ffn_up", il);
if (up_b) {
@@ -772,7 +762,6 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cb(tmp, "ffn_up_s", il);
}
-auto cur = input;
if (gate) {
switch (type_gate) {
case LLM_FFN_SEQ:
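
Aside: a minimal standalone sketch of the pattern the diff moves to, in plain C++ with a toy vector-of-floats "tensor" (not the ggml API; rms_norm, silu_mul and ffn below are made-up stand-ins). All intermediate work happens on a local cur seeded from input, so the original input stays untouched and remains available for the residual add when add_input is set, which appears to be what the new add_input branch in the fused up/gate path relies on.

// Toy illustration only; real code builds a ggml graph instead of computing values.
#include <cmath>
#include <cstdio>
#include <vector>

using tensor = std::vector<float>;

// Stand-in for llm_build_norm with LLM_NORM_RMS.
static tensor rms_norm(const tensor & t) {
    double ss = 0.0;
    for (float v : t) ss += double(v) * v;
    const float scale = 1.0f / std::sqrt(float(ss / t.size()) + 1e-6f);
    tensor out(t.size());
    for (size_t i = 0; i < t.size(); ++i) out[i] = t[i] * scale;
    return out;
}

// Stand-in for the fused up/gate step: up * SiLU(gate).
static tensor silu_mul(const tensor & up, const tensor & gate) {
    tensor out(up.size());
    for (size_t i = 0; i < up.size(); ++i) {
        const float g = gate[i] / (1.0f + std::exp(-gate[i])); // SiLU(gate)
        out[i] = up[i] * g;
    }
    return out;
}

// Analogue of the fixed llm_build_ffn flow: normalize and activate a local
// copy, then add the *original* input as the residual if requested.
static tensor ffn(const tensor & input, bool has_norm, bool add_input) {
    tensor cur = input;              // work on a local handle, never mutate input
    if (has_norm) cur = rms_norm(cur);
    cur = silu_mul(cur, cur);        // stand-in for the up/gate projections
    if (add_input) {
        for (size_t i = 0; i < cur.size(); ++i) cur[i] += input[i]; // residual uses the untouched input
    }
    return cur;
}

int main() {
    const tensor x = {0.5f, -1.0f, 2.0f, 0.25f};
    const tensor y = ffn(x, /*has_norm=*/true, /*add_input=*/true);
    for (float v : y) std::printf("%.4f ", v);
    std::printf("\n");
    return 0;
}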