WIP: it runs with wrong result

But it also looks like the backend scheduler is not going to help:
* It copies mask and input positions to GPU 0
* => RoPE ops must run on GPU 0
* => To proceed attn evaluation, GPU 1 must wait for GPU 0 to finish its
     entire attn calculation
* Same with FFN. The rms_norm gets scheduled on GPU 0. Hence, GPU 1 must
  wait for GPU 0 to finish its entore FFN calculation before it can
  start (as it needs to copy the result of rms_norm from GPU 0)
* => Seems useless without writing a bespoke TP scheduling
This commit is contained in:
Kawrakow
2025-11-26 09:27:12 +00:00
parent bc4be331ee
commit 5d68e4eb35
6 changed files with 322 additions and 66 deletions

View File

@@ -636,6 +636,44 @@ ggml_tensor * llm_build_context::llm_build_ffn(
llm_ffn_gate_type type_gate,
const llm_build_cb & cb, int il) {
if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s &&
up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR &&
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
auto unary_op = type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU :
type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU;
auto u = (ggml_split_tensor_t *)up->extra;
auto g = (ggml_split_tensor_t *)gate->extra;
auto d = (ggml_split_tensor_t *)down->extra;
GGML_ASSERT(u->n_device == g->n_device && u->n_device == d->n_device);
std::vector<ggml_tensor *> ffn;
ffn.reserve(u->n_device);
for (int id = 0; id < u->n_device; ++id) {
int il_cb = 1000*id + il;
auto split_u = u->splits[id];
auto split_g = g->splits[id];
auto split_d = d->splits[id];
GGML_ASSERT((!split_u && !split_g && split_d) || (split_u && split_g && split_d));
if (!split_u) continue;
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
cur = llm_build_lora_mm(lctx, ctx, split_d, cur);
cb(cur, "ffn_down", il_cb);
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
ffn.push_back(cur);
}
if (ffn.size() == 1) return ffn.front();
cur = ggml_add(ctx, ffn[0], ffn[1]);
cb(cur, "combine_ffn", il);
for (int id = 2; id < int(ffn.size()); ++id) {
cur = ggml_add(ctx, cur, ffn[id]);
cb(cur, "combine_ffn", il);
}
return cur;
}
if (lctx.cparams.fused_up_gate &&
up && gate && !up_b && !up_s && !gate_b && !gate_s && type_gate == LLM_FFN_PAR &&
(type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) {
@@ -1243,7 +1281,7 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
ggml_tensor * wq, ggml_tensor * bq,
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
float attention_scale, int il) {
float attention_scale, int il) const {
auto Qcur = llm_build_lora_mm(lctx, ctx0, wq, cur);
cb(Qcur, "Qcur", il);
auto Kcur = llm_build_lora_mm(lctx, ctx0, wk, cur);
@@ -1282,7 +1320,7 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
ggml_tensor * wq, ggml_tensor * bq,
ggml_tensor * wk, ggml_tensor * bk,
ggml_tensor * wv, ggml_tensor * bv,
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) {
ggml_tensor * q_norm, ggml_tensor * k_norm, float attention_scale, int il) const {
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
if (wqkv) {
@@ -1351,13 +1389,13 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
}
auto [Q, K, V] = llm_build_mul_mat_qkv(gf, cur, wq, bq, wk, bk, wv, bv, attention_scale, il);
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, n_head, n_tokens);
auto Qcur = ggml_reshape_3d(ctx0, Q, n_embd_head, Q->ne[0]/n_embd_head, n_tokens);
if (q_norm) {
Qcur = llm_build_norm(ctx0, Qcur, hparams, q_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
}
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, n_head_kv, n_tokens);
auto Kcur = ggml_reshape_3d(ctx0, K, n_embd_head, K->ne[0]/n_embd_head, n_tokens);
if (k_norm) {
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
@@ -1405,15 +1443,20 @@ ggml_cgraph * llm_build_context::build_llama() {
bool use_rope = model.arch == LLM_ARCH_LLAMA4 ? (il + 1) % hparams.n_no_rope_layer_step != 0 : true;
auto this_KQ_mask = hparams.n_swa > 0 && hparams.n_swa_pattern > 0 && il % hparams.n_swa_pattern < (hparams.n_swa_pattern - 1) ?
KQ_mask_swa : KQ_mask;
int this_n_swa = this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0;
// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// rope freq factors for llama3; may return nullptr for llama2 and other models
auto rope_factors = build_rope_factors(il);
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
struct ggml_tensor * rope_factors = build_rope_factors(il);
if (use_rope) {
cur = build_std_attention(gf, cur, inp_pos, rope_factors, this_KQ_mask, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
}
else {
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
@@ -1450,7 +1493,7 @@ ggml_cgraph * llm_build_context::build_llama() {
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr,
this_KQ_mask == KQ_mask_swa ? hparams.n_swa : 0);
this_n_swa);
}
if (il == n_layer - 1) {
@@ -1555,7 +1598,23 @@ ggml_cgraph * llm_build_context::build_llama() {
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
if (model.output->extra) {
auto output = (ggml_split_tensor_t *)model.output->extra;
std::vector<ggml_tensor *> o;
o.reserve(output->n_device);
for (int id = 0; id < output->n_device; ++id) {
auto split = output->splits[id];
if (!split) continue;
o.push_back(llm_build_lora_mm(lctx, ctx0, split, cur));
}
if (o.size() == 1) cur = o.front();
cur = ggml_concat(ctx0, o[0], o[1], 0);
for (int id = 2; id < int(o.size()); ++id) {
cur = ggml_concat(ctx0, cur, o[id], 0);
}
} else {
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
}
// For Granite architecture
if (hparams.f_logit_scale) {
@@ -3514,9 +3573,6 @@ ggml_cgraph * llm_build_context::build_qwen3() {
ggml_cgraph * llm_build_context::build_qwen3moe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -3532,10 +3588,6 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
auto rope_cache = cparams.rope_cache && (rope_type == LLAMA_ROPE_TYPE_NEOX || rope_type == LLAMA_ROPE_TYPE_NORM) ?
ggml_rope_cache(ctx0, inp_pos, nullptr, n_embd_head, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow) : nullptr;
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@@ -3543,35 +3595,11 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self_attention
{
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, nullptr,
model.layers[il].wqk, nullptr,
model.layers[il].wq, nullptr, model.layers[il].wk, nullptr, model.layers[il].wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
if (rope_cache) {
Qcur = ggml_rope_fast(ctx0, Qcur, rope_cache);
Kcur = ggml_rope_fast(ctx0, Kcur, rope_cache);
} else {
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
cur = build_std_attention(gf, cur, inp_pos, nullptr, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), 0.0f, 0, il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -3583,8 +3611,7 @@ ggml_cgraph * llm_build_context::build_qwen3moe() {
cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur =
llm_build_moe_ffn(ctx0, lctx, cur,
cur = llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
@@ -9010,3 +9037,151 @@ ggml_cgraph * llm_build_context::llama_build_graph(
return result;
}
ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
ggml_tensor * KQ_mask, ggml_tensor * sinks, float KQ_scale, float f_attn_scale, int n_swa, int il) {
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
auto wq = (ggml_split_tensor_t *)model.layers[il].wq->extra;
auto wk = (ggml_split_tensor_t *)model.layers[il].wk->extra;
auto wv = (ggml_split_tensor_t *)model.layers[il].wv->extra;
auto wo = (ggml_split_tensor_t *)model.layers[il].wo->extra;
GGML_ASSERT(wq->n_device == wk->n_device && wq->n_device == wv->n_device && wq->n_device == wo->n_device);
auto kl = (ggml_split_tensor_t *)kv_self.k_l[il]->extra;
auto vl = (ggml_split_tensor_t *)kv_self.v_l[il]->extra;
GGML_ASSERT(wq->n_device == kl->n_device && wq->n_device == vl->n_device);
std::vector<ggml_tensor*> attn; attn.reserve(wq->n_device);
for (int id = 0; id < wq->n_device; ++id) {
int il_cb = 1000*id + il;
auto split_wq = wq->splits[id];
auto split_wk = wk->splits[id];
auto split_wv = wv->splits[id];
auto split_wo = wo->splits[id];
auto split_kl = kl->splits[id];
auto split_vl = vl->splits[id];
GGML_ASSERT((!split_wq && !split_wk && !split_wv && !split_wo && !split_kl && !split_vl) ||
(split_wq && split_wk && split_wv && split_wo && split_kl && split_vl));
if (!split_wq) continue;
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr,
split_wq, nullptr, split_wk, nullptr, split_wv, nullptr,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il_cb);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il_cb);
cb(Kcur, "Kcur", il_cb);
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_head_kv = split_wk->ne[1] / n_embd_head_k;
GGML_ASSERT(kv_self.size == cparams.n_ctx);
GGML_ASSERT(2*il+1 < (int)lctx.cache_copies.size());
auto k_row_size = ggml_row_size(split_kl->type, n_embd_head_k);
ggml_tensor * k_cache_view = ggml_view_2d(ctx0, split_kl, n_embd_head_k, n_tokens*n_head_kv,
k_row_size, k_row_size*n_head_kv*kv_head);
lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx0, Kcur, k_cache_view);
lctx.cache_copies[2*il+0].step = k_row_size*n_head_kv;
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
struct ggml_tensor * v_cache_view = nullptr;
if (cparams.flash_attn) {
v_cache_view = ggml_view_1d(ctx0, split_vl, n_tokens*split_wv->ne[1],
kv_head*ggml_row_size(split_vl->type, split_wv->ne[1]));
lctx.cache_copies[2*il+1].step = ggml_row_size(split_vl->type, split_wv->ne[1]);
} else {
// note: the V cache is transposed when not using flash attention
v_cache_view = ggml_view_2d(ctx0, split_vl, n_tokens, split_wv->ne[1],
( n_ctx)*ggml_element_size(split_vl),
(kv_head)*ggml_element_size(split_vl));
lctx.cache_copies[2*il+1].step = ggml_element_size(split_vl);
Vcur = ggml_transpose(ctx0, Vcur);
}
cb(v_cache_view, "v_cache_view", il_cb);
lctx.cache_copies[2*il+1].cpy = ggml_cpy(ctx0, Vcur, v_cache_view);
ggml_build_forward_expand(gf, lctx.cache_copies[2*il+1].cpy);
auto q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
cb(q, "q", il_cb);
auto k = ggml_view_3d(ctx0, split_kl, n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(split_kl->type, n_embd_head_k)*n_head_kv, //n_embd_k_gqa),
ggml_row_size(split_kl->type, n_embd_head_k), 0);
cb(k, "k", il_cb);
auto v = ggml_view_3d(ctx0, split_vl, n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(split_vl->type, split_wv->ne[1]),
ggml_row_size(split_vl->type, n_embd_head_v), 0);
cb(v, "v", il_cb);
#ifdef GGML_USE_VULKAN
constexpr bool use_f32_precision = true;
#else
constexpr bool use_f32_precision = false;
#endif
cur = ggml_flash_attn_ext(ctx0, q, k, v, KQ_mask, KQ_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_add_sinks(cur, sinks);
if (n_swa > 0) {
((int32_t *)cur->op_params)[4] = n_swa;
}
// Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA
if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
(model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 ||
model.arch == LLM_ARCH_GLM4_MOE) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
cur = ggml_reshape_2d(ctx0, cur, split_wo->ne[0], n_tokens);
cur = llm_build_lora_mm(lctx, ctx0, split_wo, cur);
if (lctx.model.arch == LLM_ARCH_GLM4 || lctx.model.arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
cb(cur, "kqv_wo", il_cb);
// TODO: wo_b
attn.push_back(cur);
}
if (attn.size() == 1) return attn.front();
cur = ggml_add(ctx0, attn[0], attn[1]);
cb(cur, "combine_attn", il);
for (int id = 2; id < (int)attn.size(); ++id) {
cur = ggml_add(ctx0, cur, attn[id]);
cb(cur, "combine_attn", il);
}
return cur;
}
}
auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
model.layers[il].wqkv, model.layers[il].bqkv,
model.layers[il].wqk, model.layers[il].bqk,
model.layers[il].wq, model.layers[il].bq, model.layers[il].wk, model.layers[il].bk, model.layers[il].wv, model.layers[il].bv,
model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
return cur;
}