diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index bc58c5ed..27974c09 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -691,9 +691,9 @@ ggml_tensor * llm_build_context::llm_build_ffn( if (ffn.size() > 2) { cur->op_params[0] = 0xff; } - if (cur->type != GGML_TYPE_F32) { - cur = ggml_cast(ctx, cur, GGML_TYPE_F32); - } + //if (cur->type != GGML_TYPE_F32) { + // cur = ggml_cast(ctx, cur, GGML_TYPE_F32); + //} return cur; } @@ -9002,6 +9002,9 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens cur = llm_build_norm(ctx0, cur, hparams, split_norm, NULL, LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il_cb); } + else if (cur->type != GGML_TYPE_F32) { + cur = ggml_cast(ctx0, cur, GGML_TYPE_F32); + } auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur, nullptr, nullptr, nullptr, nullptr, split_wq, nullptr, split_wk, nullptr, split_wv, nullptr, model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, f_attn_scale, il_cb);