mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-27 09:53:40 +00:00
Add GGML_OP_REDUCE
This commit is contained in:
@@ -4291,9 +4291,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
|
||||
"GLU",
|
||||
|
||||
"REDUCE",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92");
|
||||
static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -4398,10 +4400,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cross_entropy_loss(x,y)",
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
|
||||
"glu(x),"
|
||||
"glu(x),",
|
||||
|
||||
"reduce(x1,x2,...)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92");
|
||||
static_assert(GGML_OP_COUNT == 93, "GGML_OP_COUNT != 93");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -5299,6 +5303,24 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
|
||||
return obj_new;
|
||||
}
|
||||
|
||||
ggml_split_tensor_t * ggml_new_split(
|
||||
struct ggml_context * ctx,
|
||||
int n_device,
|
||||
int split_dim,
|
||||
struct ggml_tensor * tensor) {
|
||||
size_t size = sizeof(int) + sizeof(int) + sizeof(struct ggml_tensor *) + sizeof(struct ggml_tensor**) + n_device*sizeof(struct ggml_tensor*);
|
||||
struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_EXTRA, size);
|
||||
ggml_split_tensor_t * result = (ggml_split_tensor_t *)((char *)ctx->mem_buffer + obj_new->offs);
|
||||
result->n_device = n_device;
|
||||
result->split_dim = split_dim;
|
||||
result->tensor = tensor;
|
||||
result->splits = (struct ggml_tensor**)(result->tensor + 1);
|
||||
for (int i = 0; i < n_device; ++i) {
|
||||
result->splits[i] = NULL;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_new_tensor_impl(
|
||||
struct ggml_context * ctx,
|
||||
enum ggml_type type,
|
||||
@@ -6060,6 +6082,57 @@ struct ggml_tensor * ggml_dup_inplace(
|
||||
return ggml_dup_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_reduce_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op op,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(op == GGML_OP_ADD); // the only op we currently support
|
||||
GGML_ASSERT(a->extra);
|
||||
ggml_split_tensor_t * extra = (ggml_split_tensor_t *)a->extra;
|
||||
GGML_ASSERT(extra->n_device > 1);
|
||||
GGML_ASSERT(extra->splits);
|
||||
int nhave = 0;
|
||||
for (int j = 0; j < extra->n_device; ++j) {
|
||||
if (extra->splits[j]) ++nhave;
|
||||
}
|
||||
GGML_ASSERT(nhave > 1);
|
||||
|
||||
struct ggml_tensor * result;
|
||||
if (inplace) {
|
||||
result = ggml_view_tensor(ctx, a);
|
||||
result->src[0] = a;
|
||||
result->extra = a->extra;
|
||||
} else {
|
||||
result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
|
||||
ggml_split_tensor_t * new_extra = ggml_new_split(ctx, extra->n_device, extra->split_dim, result);
|
||||
result->extra = new_extra;
|
||||
for (int j = 0; j < extra->n_device; ++j) {
|
||||
if (extra->splits[j]) {
|
||||
new_extra->splits[j] = ggml_dup_tensor(ctx, extra->splits[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
result->op = GGML_OP_REDUCE;
|
||||
result->op_params[0] = (int32_t)op;
|
||||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_reduce(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op op) {
|
||||
return ggml_reduce_impl(ctx, a, op, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_reduce_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op op) {
|
||||
return ggml_reduce_impl(ctx, a, op, true);
|
||||
}
|
||||
|
||||
|
||||
// ggml_add
|
||||
|
||||
static struct ggml_tensor * ggml_add_impl(
|
||||
@@ -22679,6 +22752,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
|
||||
{
|
||||
ggml_compute_forward_dup(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_REDUCE:
|
||||
{
|
||||
GGML_ABORT("Fatal error"); // TODO
|
||||
}
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
ggml_compute_forward_add(params, tensor);
|
||||
@@ -23358,6 +23435,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_REDUCE:
|
||||
{
|
||||
GGML_ABORT("Fatal error"); // TODO
|
||||
}
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
if (src0->grad) {
|
||||
|
||||
Reference in New Issue
Block a user