Graph parallel for Step-3.5-Flash (#1236)

* WIP

* This works but is slow

* Turn off the up / gate clamps for now

* OK we need the clamping

* Fuse the clamp (CUDA)

* Fuse the clamp (CPU)

* WIP

* Be able to use merged q, k, v

* Be able to use merged up/gate experts

* Fuse the clamp (CUDA mmvq)

* WIP: graph parallel for Step-3.5

* WIP

* This should be it

* Cleanup

* Fix merge
This commit is contained in:
Kawrakow
2026-02-06 06:56:51 +02:00
committed by GitHub
parent 5a44324e4a
commit 81ea911f0d
6 changed files with 113 additions and 116 deletions

View File

@@ -217,7 +217,7 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (gqa_ratio == 12 && Q->ne[1] == 1 && K->ne[1]*K->ne[2] >= 65536) {
if (false && gqa_ratio == 12 && Q->ne[1] == 1 && K->ne[1]*K->ne[2] >= 65536) {
// This is a hack to improve GLM-4.5/4.6/4.7/AIR TG performance
glm45_flash_attention(ctx, dst);
return;

View File

@@ -702,6 +702,10 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cur = do_split_norm(ctx, cur, ffn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
if (lctx.model.arch == LLM_ARCH_STEP35) {
//printf("%s(%d): limits = %g\n", __func__, il, lctx.model.hparams.swiglu_limits[il]);
*(float *)(cur->op_params + 1) = lctx.model.hparams.swiglu_limits[il];
}
cur = llm_build_lora_mm(lctx, ctx, split_d, cur);
cb(cur, "ffn_down", il_cb);
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
@@ -3587,132 +3591,60 @@ ggml_cgraph * llm_build_context::build_step35() {
auto KQ_mask = build_inp_KQ_mask();
auto KQ_mask_swa = build_inp_KQ_mask_swa();
//const float kq_scale = 1.0f / sqrtf(float(n_rot));
const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
for (int il = 0; il < n_layer; ++il) {
bool is_swa = hparams.swa_layers[il];
ggml_tensor * inpSA = inpL;
const uint32_t n_head_l = hparams.n_head(il);
const float freq_base_l = hparams.has_rope_freq_base_per_layer ? hparams.rope_freq_base_per_layer[il] :
is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
cur = inpL;
// self-attention
{
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
auto & layer = const_cast<llama_layer&>(model.layers[il]);
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq,
model.layers[il].wk, model.layers[il].bk,
model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0.f, il);
ggml_tensor * rope_factors = nullptr;
const uint32_t apply_mask = hparams.rope_scaling_apply_mask;
if ((is_swa && (apply_mask & 0x2)) || (!is_swa && (apply_mask & 0x1))) {
rope_factors = build_rope_factors(il);
}
const int64_t n_rot_l = hparams.rope_n_rot(il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);
const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
auto attn_out = llm_build_kv(ctx0, lctx, kv_self, gf, nullptr, nullptr, // i.e., do not multiply with wo
Kcur, Vcur, Qcur, is_swa ? KQ_mask_swa : KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il,
nullptr, is_swa ? hparams.n_swa : 0);
cb(attn_out, "attn_out", il);
// head-wise attention gate: sigmoid(g_proj(x)) in torch
if (model.layers[il].wqkv_gate) {
auto gate = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
cb(gate, "attn_gate", il);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "attn_gate_sigmoid", il);
// reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
cb(gate_3d, "attn_gate_bcast", il);
attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
cb(attn_3d, "attn_gated_3d", il);
//attn_out = ggml_cont_2d(ctx0, ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens),
// n_embd_head_v * n_head_l, n_tokens);
attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
cb(attn_out, "attn_gated", il);
}
// output projection
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, attn_out);
cb(cur, "attn_proj", il);
ggml_tensor * rope_factors = nullptr;
const uint32_t apply_mask = hparams.rope_scaling_apply_mask;
if ((is_swa && (apply_mask & 0x2)) || (!is_swa && (apply_mask & 0x1))) {
rope_factors = build_rope_factors(il);
}
if (il == n_layer - 1 && inp_out_ids && n_tokens > 1) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
auto rope_freqs = layer.rope_freqs;
layer.rope_freqs = nullptr;
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr,
rope_factors, is_swa ? KQ_mask_swa : KQ_mask, nullptr, nullptr, kq_scale, 0.0f, is_swa ? hparams.n_swa : 0,
il, true, false, true);
layer.rope_freqs = rope_freqs;
cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
// feed-forward
if (model.layers[il].ffn_gate_inp == nullptr) {
// dense MLP
cur = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
// dense FFN
cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
cb(cur, "ffn_out", il);
} else {
// MoE routed experts
const bool norm_w = hparams.expert_weights_norm;
const float w_scale = hparams.expert_weights_scale;
const bool scale_w = w_scale != 0.0f;
ggml_tensor * moe_out = llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
cur = llm_build_std_moe_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur,
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
model.layers[il].ffn_exp_probs_b,
model.layers[il].ffn_up_shexp, nullptr, // we don't have shared expert biases?
model.layers[il].ffn_gate_shexp, nullptr,
model.layers[il].ffn_down_shexp, nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU,
norm_w, scale_w, w_scale,
LLM_FFN_SILU, norm_w, scale_w, w_scale,
LLM_EXPERT_GATING_FUNC_SIGMOID,
cb, il, gf, false, model.layers[il].ffn_up_gate_exps);
cb(moe_out, "ffn_moe_out", il);
// shared expert MLP (always added on MoE layers in Step35)
ggml_tensor * sh_out = llm_build_ffn(ctx0, lctx, nullptr, cur,
model.layers[il].ffn_up_shexp, nullptr, nullptr,
model.layers[il].ffn_gate_shexp, nullptr, nullptr,
model.layers[il].ffn_down_shexp, nullptr, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf);
cb(sh_out, "ffn_shared_out", il);
cur = ggml_add(ctx0, moe_out, sh_out);
cb(cur, "ffn_out", il);
//(llm_expert_gating_func_type) hparams.expert_gating_func,
LLM_FFN_SILU, cb, il, gf, true, model.layers[il].ffn_up_gate_exps);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out_with_inp", il);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cur = build_output(lctx, ctx0, inpL, model.output, model.output_norm, cb);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
@@ -9554,6 +9486,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
float freq_base_l = n_swa > 0 ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
float freq_scale_l = n_swa > 0 ? hparams.rope_freq_scale_train_swa : hparams.rope_freq_scale_train;
if (hparams.has_rope_freq_base_per_layer) {
freq_base_l = hparams.rope_freq_base_per_layer[il];
}
int n_rot_l = lctx.model.hparams.rope_n_rot(il);
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_precision = true;
#else
@@ -9616,6 +9552,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
auto cur = get_input_tensor_sm_graph(ctx0, input, id);
auto input_id = cur;
cur = do_split_norm(ctx0, cur, the_attn_norm, lctx.model.hparams, cb, id, il_cb, is_norm);
auto input_normed = cur;
auto the_q_norm = model.layers[il].attn_q_norm ? model.layers[il].attn_q_norm->extra ?
((ggml_split_tensor_t *)model.layers[il].attn_q_norm->extra)->splits[id] : model.layers[il].attn_q_norm : nullptr;
auto the_k_norm = model.layers[il].attn_k_norm ? model.layers[il].attn_k_norm->extra ?
@@ -9626,7 +9563,13 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
split_wv, bv ? bv->splits[id] : nullptr,
the_q_norm, the_k_norm, f_attn_scale, il, add_graph_split);
auto rope_factors = rope_factors_in;
if (!rope_factors && model.layers[il].rope_freqs && model.layers[il].rope_freqs->extra) {
if (rope_factors) {
GGML_ASSERT(rope_factors->extra);
rope_factors = ((ggml_split_tensor_t *)rope_factors->extra)->splits[id];
GGML_ASSERT(rope_factors);
}
else if (model.layers[il].rope_freqs && model.layers[il].rope_freqs->extra) {
printf("%s(%d, %d): using model.layers[il].rope_freqs as rope_factors_in was null\n", __func__, il, id);
auto extra = (ggml_split_tensor_t *)model.layers[il].rope_freqs->extra;
rope_factors = extra->splits[id];
}
@@ -9635,15 +9578,15 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
int sections[4];
std::copy(hparams.rope_sections.begin(), hparams.rope_sections.begin() + GGML_MROPE_SECTIONS, sections);
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, rope_factors,
n_rot, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, rope_factors,
n_rot, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
}
}
@@ -9732,6 +9675,21 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
if (model.layers[il].wqkv_gate) {
auto wqkv_gate = (ggml_split_tensor_t *)model.layers[il].wqkv_gate->extra;
GGML_ASSERT(wqkv_gate && wqkv_gate->splits[id]);
auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate->splits[id], input_normed);
cb(gate, "attn_gate", il_cb);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "attn_gate_sigmoid", il_cb);
int nh = split_wo->ne[0]/n_embd_head_v;
auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, nh, n_tokens);
auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, nh, n_tokens);
gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
cur = ggml_mul(ctx0, attn_3d, gate_3d);
cb(attn_3d, "attn_gated_3d", il_cb);
}
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
cb(cur, "flash_attn_reshaped", il_cb);
@@ -9776,6 +9734,7 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cur = llm_build_norm(ctx0, cur, hparams, the_attn_norm, NULL, is_norm ? LLM_NORM : LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
}
auto input_normed = cur;
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
@@ -9788,15 +9747,15 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
int sections[4];
std::copy(hparams.rope_sections.begin(), hparams.rope_sections.begin() + GGML_MROPE_SECTIONS, sections);
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, rope_factors_in,
n_rot, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, rope_factors_in,
n_rot, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
n_rot_l, sections, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors_in, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors_in, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
}
}
@@ -9808,9 +9767,35 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
cb(Qcur, "Qcur_temp_scaled", il);
}
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
if (auto wqkv_gate = model.layers[il].wqkv_gate; wqkv_gate != nullptr) {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
nullptr, nullptr,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
cb(cur, "wqkv", il);
auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate, input_normed); // [n_head_l, n_tokens]
cb(gate, "attn_gate", il);
gate = ggml_sigmoid(ctx0, gate);
cb(gate, "attn_gate_sigmoid", il);
// reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
int n_head_l = hparams.n_head(il);
ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
cb(gate_3d, "attn_gate_bcast", il);
attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
cb(attn_3d, "attn_gated_3d", il);
cur = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
cb(cur, "attn_gated", il);
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
if (model.layers[il].bo) {
cur = ggml_add(ctx0, cur, model.layers[il].bo);
}
cb(cur, "attn_out", il);
} else {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
}
if (inp_out_ids) { // && ggml_nrows(inp_out_ids) > 1) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

View File

@@ -187,7 +187,7 @@ struct llama_hparams {
if (il < n_layer) {
return n_head_arr[il];
}
printf("%s: Oops, il = %d\n", __func__, il);
GGML_ABORT("fatal error");
}

View File

@@ -3285,6 +3285,16 @@ bool create_tensors_helper::create_tensors() {
}
prepare_split_tensors(0, ctx_split, layer.attn_sinks, layer.split_sinks, split_sinks, mem_used);
}
if (layer.wqkv_gate) {
auto wqkv_gate_split = split_kq;
LLAMA_LOG_DEBUG("=================== wqkv_gate_split:");
for (auto & s : wqkv_gate_split) {
s /= hparams.n_embd_head_k;
LLAMA_LOG_DEBUG(" %d", s);
}
LLAMA_LOG_DEBUG("\n");
prepare_split_tensors(1, ctx_split, layer.wqkv_gate, layer.split_wqkv_gate, wqkv_gate_split, mem_used);
}
for (auto & s : split_kq) s /= gqa_ratio;
for (auto & s : split_vo) s /= gqa_ratio;
if (layer.attn_k_norm && layer.attn_k_norm->ne[0] == layer.wk->ne[1]) {

View File

@@ -204,6 +204,7 @@ struct llama_layer {
llama_split_tensor split_q_norm;
llama_split_tensor split_k_norm;
llama_split_tensor split_sinks;
llama_split_tensor split_wqkv_gate;
// relative position bias
struct ggml_tensor * attn_rel_b = nullptr;

View File

@@ -1756,6 +1756,7 @@ static bool is_model_split_supported(const llama_model & model) {
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_SEED_OSS,
LLM_ARCH_STEP35,
};
auto it = k_supported.find(model.arch);
return it != k_supported.end();