Step-3.5-Flash support (#1231)

* WIP

* This works but is slow

* Turn off the up / gate clamps for now

* OK we need the clamping

* Fuse the clamp (CUDA)

* Fuse the clamp (CPU)

* WIP

* Be able to use merged q, k, v

* Be able to use merged up/gate experts

* Fuse the clamp (CUDA mmvq)
commit 9c1c74acda (parent 8d952ff183)
Author: Kawrakow
Committed via GitHub, 2026-02-05 08:13:22 +02:00
22 changed files with 487 additions and 69 deletions


@@ -755,6 +755,9 @@ ggml_tensor * llm_build_context::llm_build_ffn(
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
cur = ggml_fused_up_gate(ctx, up, gate, cur, unary_op);
cb(cur, "ffn_up_gate", il);
if (lctx.model.arch == LLM_ARCH_STEP35) {
*(float *)(cur->op_params + 1) = lctx.model.hparams.swiglu_limits_shared[il];
}
if (down) {
cur = llm_build_lora_mm(lctx, ctx, down, cur);
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
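
The op_params write above is how the per-layer SwiGLU clamp limit reaches the fused kernel. A minimal sketch of the element-wise math the fused up/gate op is then expected to perform, assuming a non-zero limit means "clamp both activations" and zero means "clamping disabled" (an assumption based on the commit messages, not a quote of the kernels in this diff):

#include <math.h>

// Sketch only: clamp gate and up to [-limit, limit], then apply SwiGLU.
static inline float clamped_swiglu(float gate, float up, float limit) {
    if (limit > 0.0f) {
        gate = fminf(fmaxf(gate, -limit), limit);
        up   = fminf(fmaxf(up,   -limit), limit);
    }
    const float silu = gate / (1.0f + expf(-gate)); // SiLU(gate) = gate * sigmoid(gate)
    return silu * up;
}
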
@@ -828,12 +831,21 @@ ggml_tensor * llm_build_context::llm_build_ffn(
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU);
if (lctx.model.arch == LLM_ARCH_STEP35) {
*((float *)(cur->op_params + 1)) = lctx.model.hparams.swiglu_limits_shared[il];
}
}
else {
switch (type_op) {
case LLM_FFN_SILU:
{
if (lctx.model.arch == LLM_ARCH_STEP35) {
cur = ggml_fused_mul_unary(ctx, cur, up, GGML_UNARY_OP_SILU);
*(float *)(cur->op_params + 1) = lctx.model.hparams.swiglu_limits_shared[il];
type_gate = LLM_FFN_SEQ;
break;
}
cur = ggml_silu(ctx, cur);
cb(cur, "ffn_silu", il);
} break;
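
On the backend side, a kernel can recover the float stashed into slot 1 of the op's parameter block. A hedged sketch (memcpy sidesteps the strict-aliasing issue of the raw pointer cast used in the graph code; the slot index is taken from the writes above):

#include <string.h>
#include "ggml.h"

// Assumption: op_params is the int32_t array in ggml_tensor, and slot 1
// holds the clamp limit written during graph build.
static float get_clamp_limit(const struct ggml_tensor * dst) {
    float limit;
    memcpy(&limit, &dst->op_params[1], sizeof(float));
    return limit;
}
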
@@ -1003,7 +1015,7 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
-if (lctx.model.arch == LLM_ARCH_BAILINGMOE2) {
+if (lctx.model.arch == LLM_ARCH_BAILINGMOE2 || lctx.model.arch == LLM_ARCH_STEP35) {
weights_sum = ggml_scale_bias(ctx, weights_sum, 1.0, 1e-20);
cb(weights_sum, "ffn_moe_weights_sum_biased", il);
}
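
Why the 1e-20 bias: with sigmoid gating the selected-expert scores are not softmax-normalized, so their row sum can underflow to zero and the later division by weights_sum would produce inf/nan. A sketch of the normalization this protects (illustration only, not the actual graph code):

// weights_sum = sum of the n_expert_used sigmoid scores for one token
const float denom = weights_sum + 1e-20f; // matches ggml_scale_bias(..., 1.0, 1e-20)
for (int i = 0; i < n_expert_used; ++i) {
    weights[i] /= denom; // stays finite even if every score underflows
}
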
@@ -1036,7 +1048,7 @@ llm_expert_gating_func_type gating_op,
// Hence, if we have biases, we cannot use fmoe.
//
//bool can_use_fmoe = !up_exps_b && !gate_exps_b && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU);
-bool can_use_fmoe = type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE;
+bool can_use_fmoe = (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU || type_op == LLM_FFN_SWIGLU_OAI_MOE);
ggml_tensor * par;
if (can_use_fmoe && up_gate_exps) {
@@ -1049,6 +1061,9 @@ llm_expert_gating_func_type gating_op,
par = ggml_moe_up_gate(ctx, up_gate_exps, nullptr, cur, selected_experts,
type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
}
if (lctx.model.arch == LLM_ARCH_STEP35) {
*((float *)(par->op_params + 1)) = lctx.model.hparams.swiglu_limits[il];
}
} else {
GGML_ASSERT(!up_gate_exps && !up_gate_exps_b);
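
The up_gate_exps path corresponds to the "Be able to use merged up/gate experts" commit bullet: both projections live in one weight tensor, so a single expert matmul feeds the fused kernel. A sketch under an assumed row layout (up first, then gate; the actual ordering used by the conversion may differ):

#include <math.h>

// y: one expert's merged matmul output of length 2*n_ff.
// Assumed layout: y[0..n_ff) = up, y[n_ff..2*n_ff) = gate.
static void fused_up_gate_row(const float * y, float * out, int n_ff, float limit) {
    for (int i = 0; i < n_ff; ++i) {
        float up   = y[i];
        float gate = y[n_ff + i];
        if (limit > 0.0f) {
            gate = fminf(fmaxf(gate, -limit), limit);
            up   = fminf(fmaxf(up,   -limit), limit);
        }
        out[i] = up * gate / (1.0f + expf(-gate)); // clamped SwiGLU
    }
}
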
@@ -1062,6 +1077,9 @@ llm_expert_gating_func_type gating_op,
par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts,
type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
}
if (lctx.model.arch == LLM_ARCH_STEP35) {
*(float *)(par->op_params + 1) = lctx.model.hparams.swiglu_limits[il];
}
} else {
ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
@@ -1087,6 +1105,9 @@ llm_expert_gating_func_type gating_op,
if (type_op == LLM_FFN_SILU || type_op == LLM_FFN_GELU) {
par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
if (lctx.model.arch == LLM_ARCH_STEP35) {
*((float *)(par->op_params + 1)) = lctx.model.hparams.swiglu_limits[il];
}
} else if (type_op == LLM_FFN_SWIGLU_OAI_MOE) {
constexpr float alpha = 1.702f;
constexpr float limit = 7.0f;
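
For contrast with Step-3.5's learned per-layer limits, the SWIGLU_OAI_MOE branch uses fixed constants. Its activation is commonly written as below (a sketch of the gpt-oss style formulation; the exact clamping details are stated here as an assumption, not quoted from this file):

#include <math.h>

// The gate is clamped from above only; the linear term is clamped to
// [-limit, limit] and shifted by +1.
static inline float swiglu_oai(float gate, float up, float alpha, float limit) {
    gate = fminf(gate, limit);
    up   = fminf(fmaxf(up, -limit), limit);
    return (up + 1.0f) * gate / (1.0f + expf(-alpha * gate)); // alpha = 1.702
}
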
@@ -1655,8 +1676,10 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il, bool add_graph_split) const {
+int n_head = hparams.n_head(il);
+int n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
-const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il);
if (wqkv) {
auto qkv = llm_build_lora_mm(lctx, ctx0, wqkv, cur);
if (add_graph_split) {
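
The switch from n_embd_v_gqa() to n_embd_v_gqa(il) matters because the per-layer overload folds in that layer's own KV head count. Sketch of the relationship, assuming the usual llama.cpp hparams definition:

// Assumption: n_embd_v_gqa(il) == n_embd_head_v * n_head_kv(il).
// The parameterless overload uses a single global KV head count, which is
// wrong for a model whose n_head_kv varies from layer to layer.
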
@@ -3555,6 +3578,147 @@ ggml_cgraph * llm_build_context::build_seedoss() {
return gf;
}
ggml_cgraph * llm_build_context::build_step35() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
ggml_tensor * cur;
auto inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
auto inp_pos = build_inp_pos();
auto inp_out_ids = build_inp_out_ids();
auto KQ_mask = build_inp_KQ_mask();
auto KQ_mask_swa = build_inp_KQ_mask_swa();
//const float kq_scale = 1.0f / sqrtf(float(n_rot));
for (int il = 0; il < n_layer; ++il) {
bool is_swa = hparams.swa_layers[il];
ggml_tensor * inpSA = inpL;
const uint32_t n_head_l = hparams.n_head(il);
const float freq_base_l = hparams.has_rope_freq_base_per_layer ? hparams.rope_freq_base_per_layer[il] :
is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
cur = inpL;
// self-attention
{
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
ggml_tensor * rope_factors = nullptr;
const uint32_t apply_mask = hparams.rope_scaling_apply_mask;
if ((is_swa && (apply_mask & 0x2)) || (!is_swa && (apply_mask & 0x1))) {
rope_factors = build_rope_factors(il);
}
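// Reading of the mask above, inferred from the two branches: bit 0x1
// enables the rope-scaling factors on full-attention layers, bit 0x2 on
// sliding-window (SWA) layers.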
const int64_t n_rot_l = hparams.rope_n_rot(il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);
const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
auto attn_out = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr, // i.e., do not multiply with wo
Kcur, Vcur, Qcur, is_swa ? KQ_mask_swa : KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il,
nullptr, is_swa ? hparams.n_swa : 0);
cb(attn_out, "attn_out", il);
// head-wise attention gate: sigmoid(g_proj(x)) in torch
if (model.layers[il].wqkv_gate) {
auto gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
cb(gate, "attn_gate", il);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "attn_gate_sigmoid", il);
// reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
cb(gate_3d, "attn_gate_bcast", il);
attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
cb(attn_3d, "attn_gated_3d", il);
//attn_out = ggml_cont_2d(ctx0, ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens),
// n_embd_head_v * n_head_l, n_tokens);
attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
cb(attn_out, "attn_gated", il);
}
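// Shape walk-through of the gating above (one sigmoid scalar per head and token):
//   gate     : [n_head_l, n_tokens]
//   attn_out : [n_embd_head_v * n_head_l, n_tokens]
// After the reshape + repeat, gate[h, t] scales all n_embd_head_v values of
// head h, i.e. attn[:, h, t] *= sigmoid(g_proj(x)[h, t]).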
// output projection
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn_out);
cb(cur, "attn_proj", il);
}
if (il == n_layer - 1 && inp_out_ids && n_tokens > 1) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
// feed-forward
if (model.layers[il].ffn_gate_inp == nullptr) {
// dense MLP
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
} else {
// MoE routed experts
const bool norm_w = hparams.expert_weights_norm;
const float w_scale = hparams.expert_weights_scale;
const bool scale_w = w_scale != 0.0f;
ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU,
norm_w, scale_w, w_scale,
LLM_EXPERT_GATING_FUNC_SIGMOID,
cb, il, gf, false, model.layers[il].ffn_up_gate_exps);
cb(moe_out, "ffn_moe_out", il);
// shared expert MLP (always added on MoE layers in Step35)
ggml_tensor * sh_out = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, nullptr, nullptr,
model.layers[il].ffn_gate_shexp, nullptr, nullptr,
model.layers[il].ffn_down_shexp, nullptr, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf);
cb(sh_out, "ffn_shared_out", il);
cur = ggml_add(ctx0, moe_out, sh_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out_with_inp", il);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
ggml_cgraph * llm_build_context::build_qwen() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -9360,6 +9524,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
{
result = llm.build_seedoss();
} break;
case LLM_ARCH_STEP35:
{
result = llm.build_step35();
} break;
default:
GGML_ABORT("fatal error");
}