mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
multi_sdd: WIP
This commit is contained in:
@@ -494,6 +494,7 @@ extern "C" {
|
|||||||
GGML_OP_GROUP_NORM,
|
GGML_OP_GROUP_NORM,
|
||||||
GGML_OP_FUSED_RMS_NORM,
|
GGML_OP_FUSED_RMS_NORM,
|
||||||
GGML_OP_FUSED_MUL_UNARY,
|
GGML_OP_FUSED_MUL_UNARY,
|
||||||
|
GGML_OP_MULTI_ADD,
|
||||||
|
|
||||||
GGML_OP_MUL_MAT,
|
GGML_OP_MUL_MAT,
|
||||||
GGML_OP_MUL_MAT_ID,
|
GGML_OP_MUL_MAT_ID,
|
||||||
@@ -930,6 +931,10 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b);
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_multi_add(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor ** a);
|
||||||
|
|
||||||
// dst = a
|
// dst = a
|
||||||
// view(dst, nb1, nb2, nb3, offset) += b
|
// view(dst, nb1, nb2, nb3, offset) += b
|
||||||
// return dst
|
// return dst
|
||||||
|
|||||||
117
ggml/src/ggml.c
117
ggml/src/ggml.c
@@ -3338,6 +3338,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"GROUP_NORM",
|
"GROUP_NORM",
|
||||||
"FUSED_RMS_NORM",
|
"FUSED_RMS_NORM",
|
||||||
"FUSED_MUL_UNARY",
|
"FUSED_MUL_UNARY",
|
||||||
|
"MULTI_ADD",
|
||||||
|
|
||||||
"MUL_MAT",
|
"MUL_MAT",
|
||||||
"MUL_MAT_ID",
|
"MUL_MAT_ID",
|
||||||
@@ -3401,7 +3402,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
|
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
@@ -3430,6 +3431,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"group_norm(x)",
|
"group_norm(x)",
|
||||||
"fused_rms_norm(x)",
|
"fused_rms_norm(x)",
|
||||||
"fused_mul_unary(x)",
|
"fused_mul_unary(x)",
|
||||||
|
"x1+x2+x3+...",
|
||||||
|
|
||||||
"X*Y",
|
"X*Y",
|
||||||
"X[i]*Y",
|
"X[i]*Y",
|
||||||
@@ -3493,7 +3495,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78");
|
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
@@ -5106,6 +5108,49 @@ struct ggml_tensor * ggml_add_inplace(
|
|||||||
return ggml_add_impl(ctx, a, b, true);
|
return ggml_add_impl(ctx, a, b, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_add
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_multi_add(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor ** a) {
|
||||||
|
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
struct ggml_tensor * a_used[GGML_MAX_SRC];
|
||||||
|
int n_used = 0;
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||||
|
if (a[i]) {
|
||||||
|
a_used[n_used++] = a[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_used < 2) {
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
if (n_used == 2) {
|
||||||
|
return ggml_add(ctx, a_used[0], a_used[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 1; i < n_used; ++i) {
|
||||||
|
if (!ggml_are_same_shape(a_used[i], a[0])) {
|
||||||
|
GGML_ABORT("fayal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a_used[0]);
|
||||||
|
|
||||||
|
result->op = GGML_OP_MULTI_ADD;
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
for (int i = 1; i < n_used; ++i) {
|
||||||
|
result->src[i] = a_used[i];
|
||||||
|
}
|
||||||
|
for (int i = n_used; i < GGML_MAX_SRC; ++i) {
|
||||||
|
result->src[i] = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_add_cast
|
// ggml_add_cast
|
||||||
|
|
||||||
static struct ggml_tensor * ggml_add_cast_impl(
|
static struct ggml_tensor * ggml_add_cast_impl(
|
||||||
@@ -10425,6 +10470,65 @@ static void ggml_compute_forward_add(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_multi_add_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||||
|
if (dst->src[i]) {
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(dst->src[i], dst));
|
||||||
|
GGML_ASSERT(dst->src[i]->nb[0] == sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nr = ggml_nrows(dst);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
int64_t ne0 = dst->ne[0];
|
||||||
|
int64_t ne1 = dst->ne[1];
|
||||||
|
int64_t ne2 = dst->ne[2];
|
||||||
|
|
||||||
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
|
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||||
|
const int64_t i3 = ir/(ne2*ne1);
|
||||||
|
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||||
|
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||||
|
|
||||||
|
float * dst_ptr = (float *) ((char *) dst->data + i3*dst->nb[3] + i2*dst->nb[2] + i1*dst->nb[1] );
|
||||||
|
memset(dst_ptr, 0, ne0*sizeof(float));
|
||||||
|
for (int i = 0; i < GGML_MAX_SRC; ++i) {
|
||||||
|
struct ggml_tensor * src = dst->src[i];
|
||||||
|
if (!src) continue;
|
||||||
|
const float * data = (const float *) ((const char *) src->data + i3*src->nb[3] + i2*src->nb[2] + i1*src->nb[1]);
|
||||||
|
ggml_vec_add_f32(ne0, dst_ptr, dst_ptr, data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_multi_add(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
switch (dst->type) {
|
||||||
|
case GGML_TYPE_F32: {
|
||||||
|
ggml_compute_forward_multi_add_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default: {
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_add1
|
// ggml_compute_forward_add1
|
||||||
|
|
||||||
static void ggml_compute_forward_add1_f32(
|
static void ggml_compute_forward_add1_f32(
|
||||||
@@ -18202,6 +18306,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_add1(params, tensor);
|
ggml_compute_forward_add1(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_MULTI_ADD:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_multi_add(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_acc(params, tensor);
|
ggml_compute_forward_acc(params, tensor);
|
||||||
@@ -18947,6 +19055,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
{
|
{
|
||||||
GGML_ABORT("fatal error"); // TODO: implement
|
GGML_ABORT("fatal error"); // TODO: implement
|
||||||
}
|
}
|
||||||
|
case GGML_OP_MULTI_ADD:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error"); // TODO: implement
|
||||||
|
}
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
{
|
{
|
||||||
GGML_ABORT("fatal error"); // TODO: implement
|
GGML_ABORT("fatal error"); // TODO: implement
|
||||||
@@ -19996,6 +20108,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_ADD1:
|
case GGML_OP_ADD1:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
|
case GGML_OP_MULTI_ADD:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
|
|||||||
@@ -8351,16 +8351,42 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
|||||||
|
|
||||||
experts = ggml_mul(ctx, experts, weights);
|
experts = ggml_mul(ctx, experts, weights);
|
||||||
|
|
||||||
|
if (n_expert_used == 1) {
|
||||||
|
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
|
||||||
|
}
|
||||||
|
if (n_expert_used == 2) {
|
||||||
|
return ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
|
||||||
|
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
|
||||||
|
}
|
||||||
|
if (n_expert_used <= GGML_MAX_SRC) {
|
||||||
|
ggml_tensor * src[GGML_MAX_SRC];
|
||||||
|
for (int i = 0; i < n_expert_used; ++i) {
|
||||||
|
src[i] = ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
|
||||||
|
}
|
||||||
|
for (int i = n_expert_used; i < GGML_MAX_SRC; ++i) src[i] = nullptr;
|
||||||
|
return ggml_multi_add(ctx, src);
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
|
||||||
|
//int nloop = (n_expert_used + GGML_MAX_SRC - 1)/GGML_MAX_SRC;
|
||||||
|
|
||||||
// aggregate experts
|
// aggregate experts
|
||||||
ggml_tensor * moe_out = nullptr;
|
ggml_tensor * moe_out = nullptr;
|
||||||
|
//ggml_tensor * first_expert = nullptr;
|
||||||
for (int i = 0; i < n_expert_used; ++i) {
|
for (int i = 0; i < n_expert_used; ++i) {
|
||||||
ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
|
||||||
experts->nb[2], i*experts->nb[1]);
|
experts->nb[2], i*experts->nb[1]);
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
moe_out = cur_expert;
|
moe_out = cur_expert;
|
||||||
|
//first_expert = cur_expert;
|
||||||
|
//printf("%s: %d: %d x %d x %d x %d | %d x %d x %d x %d\n", __func__, ggml_is_contiguous(first_expert),
|
||||||
|
// (int)cur_expert->ne[0], (int)cur_expert->ne[1], (int)cur_expert->ne[2], (int)cur_expert->ne[3],
|
||||||
|
// (int)cur_expert->nb[0], (int)cur_expert->nb[1], (int)cur_expert->nb[2], (int)cur_expert->nb[3]);
|
||||||
} else {
|
} else {
|
||||||
moe_out = ggml_add(ctx, moe_out, cur_expert);
|
moe_out = ggml_add(ctx, moe_out, cur_expert);
|
||||||
|
//printf("%s: %d %d\n", __func__, ggml_is_contiguous(cur_expert), ggml_are_same_shape(cur_expert, first_expert));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -9011,6 +9037,7 @@ struct llm_build_context {
|
|||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||||
if (hparams.f_attention_scale != 0) {
|
if (hparams.f_attention_scale != 0) {
|
||||||
|
// Why is hparams.f_attention_scale not simply absorbed into model.layers[il].wq ?
|
||||||
Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
|
Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
|
||||||
}
|
}
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
@@ -9062,6 +9089,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
// For Granite architecture
|
// For Granite architecture
|
||||||
if (hparams.f_residual_scale) {
|
if (hparams.f_residual_scale) {
|
||||||
|
// Why is hparams.f_residual_scale not simply absorbed into model.layers[il].wv ?
|
||||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -9103,6 +9131,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
// For Granite architecture
|
// For Granite architecture
|
||||||
if (hparams.f_residual_scale) {
|
if (hparams.f_residual_scale) {
|
||||||
|
// Why is hparams.f_residual_scale not simply absorbed into model.layers[il].ffn_down_exps ?
|
||||||
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -9128,6 +9157,7 @@ struct llm_build_context {
|
|||||||
|
|
||||||
// For Granite architecture
|
// For Granite architecture
|
||||||
if (hparams.f_logit_scale) {
|
if (hparams.f_logit_scale) {
|
||||||
|
// Why is hparams.f_logit_scale not simply absorbed into model.output ?
|
||||||
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user