mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-05-11 00:20:19 +00:00)
Revive fused delta-net
@@ -678,6 +678,7 @@ extern "C" {
         GGML_OP_TRI,
         GGML_OP_FILL,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_DELTA_NET,

         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
@@ -2508,6 +2509,15 @@ extern "C" {
             bool                  lower,
             bool                  uni);

+    GGML_API struct ggml_tensor * ggml_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
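For orientation, a minimal sketch of what a call to the new fused op looks like, using the tensor layouts asserted by ggml_delta_net further down in this commit; the dimension values (head size 128, 16 heads, 1 token, 1 sequence) are illustrative, not from the commit:

    // q/k/v : [S, n_tokens, H, n_seqs]; g : [n_tokens, 1, H, n_seqs];
    // beta  : [1, n_tokens, H, n_seqs]; state : [S_v, S_v*H_v, 1, n_seqs]
    struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 1, 16, 1);
    struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 1, 16, 1);
    struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 1, 16, 1);
    struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,   1, 1, 16, 1);
    struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32,   1, 1, 16, 1);
    struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 128 * 16, 1, 1);
    // the result is a flat F32 tensor: n_tokens*S_v*H_v outputs, then the updated state
    struct ggml_tensor * out = ggml_delta_net(ctx, q, k, v, g, beta, state);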
@@ -55,6 +55,7 @@
 #include "ggml-cuda/hadamard.cuh"
 #include "ggml-cuda/reduce.cuh"
 #include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/delta-net.cuh"

 #include <algorithm>
 #include <array>
@@ -3698,6 +3699,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOLVE_TRI:
             ggml_cuda_op_solve_tri(ctx, dst);
             break;
+        case GGML_OP_DELTA_NET:
+            ggml_cuda_op_delta_net(ctx, dst);
+            break;
         case GGML_OP_FLASH_ATTN_EXT:
             ggml_cuda_flash_attn_ext(ctx, dst);
             break;
@@ -4557,6 +4561,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 op->src[2]->ne[1] == op->src[0]->ne[1] &&
                 op->src[1]->ne[0] == op->src[0]->ne[1] &&
                 op->src[3]->ne[0] == op->src[0]->ne[2];
+        case GGML_OP_DELTA_NET:
+            return true;
         case GGML_OP_FLASH_ATTN_EXT:
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
ggml/src/ggml-cuda/delta-net.cu (new file, 1698 lines; diff suppressed because it is too large)
ggml/src/ggml-cuda/delta-net.cuh (new file, 3 lines)
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml.c (197 changed lines)
@@ -4277,6 +4277,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TRI",
     "FILL",
     "SOLVE_TRI",
+    "DELTA_NET",

     "MAP_UNARY",
     "MAP_BINARY",
@@ -4299,7 +4300,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "FUSED_NORM",
 };

-static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
+static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4395,6 +4396,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "tri(x)",
     "fill(x)",
     "solve_tri(x)",
+    "delta_net",

     "f(x)",
     "f(x,y)",
@@ -4417,7 +4419,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "norm(x,y)",
 };

-static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
+static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -9869,6 +9871,59 @@ struct ggml_tensor * ggml_tri(
     return result;
 }

+struct ggml_tensor * ggml_delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(q->type == GGML_TYPE_F32);
+    GGML_ASSERT(k->type == GGML_TYPE_F32);
+    GGML_ASSERT(v->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(beta->type == GGML_TYPE_F32);
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t n_tokens = q->ne[1];
+    const int64_t H_k      = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[2];
+
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
+    const int64_t state_size  = S_v * S_v * H_v * n_seqs;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size);
+
+    result->op     = GGML_OP_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g;
+    result->src[4] = beta;
+    result->src[5] = state;
+
+    return result;
+}
+
 // ggml_fill

 static struct ggml_tensor * ggml_fill_impl(
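The fused op packs everything into one flat tensor: the first output_size floats are the per-token outputs, the remaining state_size floats are the updated recurrent state. A minimal unpacking sketch (it mirrors what build_fused_delta_net does further down in this commit):

    struct ggml_tensor * out_tokens = ggml_view_1d(ctx, result, output_size, 0);
    struct ggml_tensor * new_state  = ggml_view_1d(ctx, result, state_size,
            output_size * ggml_element_size(result));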
@@ -22476,6 +22531,138 @@ static void ggml_compute_forward_solve_tri(const struct ggml_compute_params * pa
     }
 }

+// ggml_compute_forward_delta_net
+
+static void ggml_compute_forward_delta_net_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+    const struct ggml_tensor * src3 = dst->src[3];
+    const struct ggml_tensor * src4 = dst->src[4];
+    const struct ggml_tensor * src5 = dst->src[5];
+
+    const int64_t head_dim = src0->ne[0];
+    const int64_t n_tokens = src0->ne[1];
+    const int64_t n_heads  = src0->ne[2];
+    const int64_t n_seqs   = src0->ne[3];
+
+    const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
+
+    const float * q_data    = (const float *) src0->data;
+    const float * k_data    = (const float *) src1->data;
+    const float * v_data    = (const float *) src2->data;
+    const float * g_data    = (const float *) src3->data;
+    const float * beta_data = (const float *) src4->data;
+    const float * state_in  = (const float *) src5->data;
+    float * out_data  = (float *) dst->data;
+    float * state_out = out_data + output_size;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t total_heads = n_heads * n_seqs;
+    const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
+    const int64_t h_start = ith * heads_per_thread;
+    const int64_t h_end   = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;
+
+    const float eps   = 1e-12f;
+    const float scale = 1.0f / sqrtf((float) head_dim);
+
+    float * v_new_buf = (float *) malloc(head_dim * sizeof(float));
+    if (!v_new_buf) {
+        return;
+    }
+
+    for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
+        const int64_t batch_idx = h_idx / n_heads;
+        const int64_t head_idx  = h_idx % n_heads;
+
+        const int64_t qkv_head_offset   = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
+        const int64_t qkv_token_stride  = head_dim;
+        const int64_t g_head_offset     = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
+        const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
+        const int64_t out_head_offset   = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
+        const int64_t out_token_stride  = head_dim * n_heads;
+
+        for (int64_t i = 0; i < head_dim * head_dim; ++i) {
+            state_out[state_head_offset + i] = state_in[state_head_offset + i];
+        }
+
+        float * state = state_out + state_head_offset;
+
+        for (int64_t t = 0; t < n_tokens; ++t) {
+            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
+            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
+            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;
+
+            const float g_val    = g_data[g_head_offset + t];
+            const float beta_raw = beta_data[g_head_offset + t];
+
+            float q_norm_sq = 0.0f;
+            float k_norm_sq = 0.0f;
+            for (int64_t i = 0; i < head_dim; ++i) {
+                q_norm_sq += q_t[i] * q_t[i];
+                k_norm_sq += k_t[i] * k_t[i];
+            }
+            const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
+            const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
+
+            const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
+            const float decay    = expf(fminf(g_val, 50.0f));
+
+            float attn_score = 0.0f;
+            for (int64_t i = 0; i < head_dim; ++i) {
+                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
+            }
+
+            float * out_t = out_data + out_head_offset + t * out_token_stride;
+
+            for (int64_t row = 0; row < head_dim; ++row) {
+                float v_prime = 0.0f;
+                float out_val = 0.0f;
+
+                for (int64_t col = 0; col < head_dim; ++col) {
+                    const float k_col = k_t[col] * k_norm_inv;
+                    const float q_col = q_t[col] * q_norm_inv * scale;
+                    const float s     = state[row + col * head_dim];
+
+                    v_prime += s * k_col * beta_val * decay;
+                    out_val += s * q_col * decay;
+                }
+
+                const float v_new = v_t[row] * beta_val - v_prime;
+                v_new_buf[row] = v_new;
+                out_t[row] = out_val + v_new * attn_score;
+            }
+
+            for (int64_t col = 0; col < head_dim; ++col) {
+                const float k_col = k_t[col] * k_norm_inv;
+                for (int64_t row = 0; row < head_dim; ++row) {
+                    float s = state[row + col * head_dim];
+                    s = decay * s + v_new_buf[row] * k_col;
+                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
+                }
+            }
+        }
+    }
+
+    free(v_new_buf);
+}
+
+static void ggml_compute_forward_delta_net(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            ggml_compute_forward_delta_net_f32(params, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
 // ggml_compute_forward_win_part

 static void ggml_compute_forward_win_part_f32(
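Read as math, the token loop above is a gated delta-rule step per head (an editorial reading of the code, not text from the commit). With k̂_t and q̂_t the L2-normalized key and query, β_t = sigmoid(beta_t), decay a_t = exp(min(g_t, 50)), and scale = 1/sqrt(head_dim):

    v_new_t = β_t v_t − a_t β_t (S_{t−1} k̂_t)
    o_t     = a_t S_{t−1} (q̂_t · scale) + (k̂_t · q̂_t · scale) v_new_t
    S_t     = clamp(a_t S_{t−1} + v_new_t k̂_tᵀ, ±1e6)

The dst buffer then holds o_t for every token, followed by the final S_T of each head.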
@@ -24202,6 +24389,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
             {
                 ggml_compute_forward_solve_tri(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET:
+            {
+                ggml_compute_forward_delta_net(params, tensor);
+            } break;
         case GGML_OP_WIN_PART:
             {
                 ggml_compute_forward_win_part(params, tensor);
@@ -25260,6 +25451,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_TRI:
         case GGML_OP_FILL:
         case GGML_OP_SOLVE_TRI:
+        case GGML_OP_DELTA_NET:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
            }
@@ -25990,6 +26182,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FUSED_UP_GATE:
         case GGML_OP_OUT_PROD:
         case GGML_OP_SOLVE_TRI:
+        case GGML_OP_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;

@@ -372,6 +372,76 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_delta_net_autoregressiv
     return {core_attn_out, state};
 }

+std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_fused_delta_net(ggml_context * ctx0,
+        ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
+        ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
+        int il, const llm_build_cb & cb) {
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+    GGML_ASSERT(H_k == H_v);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(g, "g_in", il);
+    cb(state, "state_in", il);
+
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
+    beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3), 1, n_tokens, H_k, n_seqs);
+
+    ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs);
+    if (!ggml_is_contiguous(state_flat)) {
+        state_flat = ggml_cont_4d(ctx0, state_flat, S_v, S_v * H_v, 1, n_seqs);
+    }
+
+    cb(q, "q_fused", il);
+    cb(k, "k_fused", il);
+    cb(v, "v_fused", il);
+    cb(g, "g_fused", il);
+    cb(beta, "beta_fused", il);
+    cb(state_flat, "state_fused", il);
+
+    ggml_tensor * fused_result = ggml_delta_net(ctx0, q, k, v, g, beta, state_flat);
+    cb(fused_result, "delta_net_fused_raw", il);
+
+    const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
+    const int64_t state_size  = S_v * S_v * H_v * n_seqs;
+
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, fused_result,
+            S_v, H_v, n_tokens, n_seqs,
+            ggml_row_size(fused_result->type, S_v),
+            ggml_row_size(fused_result->type, S_v * H_v),
+            ggml_row_size(fused_result->type, S_v * H_v * n_tokens), 0);
+    output_tokens = ggml_cont_4d(ctx0, output_tokens, S_v, H_v, n_tokens, n_seqs);
+
+    ggml_tensor * new_state_flat = ggml_view_1d(ctx0, fused_result, state_size,
+            output_size * ggml_element_size(fused_result));
+    ggml_tensor * new_state = ggml_reshape_4d(ctx0, new_state_flat, S_v, S_v, H_v, n_seqs);
+
+    cb(output_tokens, "output_tokens", il);
+    cb(new_state, "new_state", il);
+
+    return {output_tokens, new_state};
+}
+
 std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb) const {
     auto & model = lctx.model;
     const int64_t n_tok = input->ne[1];
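A note on the layout shuffling above (an editorial observation, not commit text): the graph-side tensors arrive as (S, H, n_tokens, n_seqs), while ggml_delta_net asserts (S, n_tokens, H, n_seqs) for q/k/v, so axes 1 and 2 must be swapped and then materialized. ggml_permute's trailing arguments give the destination position of each source axis, and the permuted tensor is only a strided view, hence the ggml_cont_4d calls. For a hypothetical tensor t:

    // t->ne = { S_k, H_k, n_tokens, n_seqs }
    ggml_tensor * tp = ggml_permute(ctx0, t, 0, 2, 1, 3);
    // tp->ne = { S_k, n_tokens, H_k, n_seqs }, but the strides still index t;
    // ggml_cont_4d(ctx0, tp, ...) copies it into the contiguous layout the op asserts.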
@@ -610,7 +680,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
     GGML_ASSERT(diag_mask != nullptr);

     attn_out = n_tok == 1
-        ? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
+        //? build_delta_net_autoregressive(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
+        ? build_fused_delta_net(ctx0, q_conv, k_conv, v_conv, gate, beta, state, il, cb)
         : build_delta_net_chunking(ctx0, q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il, cb);
     ggml_tensor * output = attn_out.first;
     ggml_tensor * new_state = attn_out.second;

@@ -19,6 +19,11 @@ struct delta_net {
            ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
            int il, const llm_build_cb & cb);

+    static std::pair<ggml_tensor *, ggml_tensor *> build_fused_delta_net(ggml_context * ctx0,
+            ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
+            ggml_tensor * g, ggml_tensor * beta, ggml_tensor * state,
+            int il, const llm_build_cb & cb);
+
     std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(ggml_context * ctx0, ggml_tensor * input, int il, const llm_build_cb & cb) const;

     ggml_tensor * build_layer_attn_linear_core(ggml_context * ctx0, ggml_cgraph * gf,