bitnet(scale in a separate tensor): mul -> scale on the CPU

This commit is contained in:
Iwan Kawrakow
2024-06-20 08:21:25 +03:00
parent e73ae1f6d3
commit 36374ab37d
2 changed files with 29 additions and 9 deletions

17
ggml.c
View File

@@ -10156,6 +10156,23 @@ static void ggml_compute_forward_mul_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ggml_nelements(dst->src[1]) == 1 && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst) &&
dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
int64_t nelements = ggml_nelements(dst->src[0]);
int64_t n_per_thread = (nelements + nth - 1)/nth;
n_per_thread = MAX(1024, n_per_thread);
int64_t start = n_per_thread*ith;
if (start >= nelements) return;
int64_t end = MIN(nelements, start + n_per_thread);
const float * src = (const float *)dst->src[0]->data + start;
float * res = (float *)dst->data + start;
if (res != src) {
memcpy(res, src, (end - start)*sizeof(float));
}
ggml_vec_scale_f32(end - start, res, *(const float *)dst->src[1]->data);
return;
}
const int64_t nr = ggml_nrows(src0);
GGML_TENSOR_BINARY_OP_LOCALS

View File

@@ -11822,11 +11822,13 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
#define BITNET_MUL ggml_mul
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
Qcur = BITNET_MUL(ctx0, Qcur, model.layers[il].wq_scale);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -11835,7 +11837,7 @@ struct llm_build_context {
// B1.K
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
Kcur = BITNET_MUL(ctx0, Kcur, model.layers[il].wk_scale);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -11844,7 +11846,7 @@ struct llm_build_context {
// B1.V
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
Vcur = BITNET_MUL(ctx0, Vcur, model.layers[il].wv_scale);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11938,7 +11940,7 @@ struct llm_build_context {
ggml_build_forward_expand(gf, cur_attn);
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur_attn);
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
cur = BITNET_MUL(ctx0, cur, model.layers[il].wo_scale);
cb(cur, "kqv_out", il);
}
@@ -11961,12 +11963,12 @@ struct llm_build_context {
cb(cur, "ffn_norm", il);
struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale);
tmp = BITNET_MUL(ctx0, tmp, model.layers[il].ffn_up_scale);
cb(tmp, "ffn_up", il);
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale);
cur = BITNET_MUL(ctx0, cur, model.layers[il].ffn_gate_scale);
cb(cur, "ffn_gate", il);
@@ -11974,7 +11976,7 @@ struct llm_build_context {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_silu", il);
cur = ggml_mul(ctx0, cur, tmp);
cur = BITNET_MUL(ctx0, cur, tmp);
cb(cur, "ffn_gate_par", il);
cur = llm_build_norm(ctx0, cur, hparams,
@@ -11983,7 +11985,7 @@ struct llm_build_context {
cb(cur, "ffn_sub_norm", il);
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
cur = BITNET_MUL(ctx0, cur, model.layers[il].ffn_down_scale);
cb(cur, "ffn_down", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
@@ -12007,6 +12009,7 @@ struct llm_build_context {
ggml_build_forward_expand(gf, cur);
return gf;
}
#undef BITNET_MUL
};