Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-04-29 10:51:51 +00:00)
Fused delta-net (#1315)
* Revive fused delta-net
* Add command line argument for fused delta net
* Simplify/improve CUDA delta-net
* Add -fdn to llama-bench
* More CUDA fused delta net optimizations
* CPU optimizations
* Much faster fused delta-net on the CPU. It seems it is faster than the chunked implementation!
* Change meaning of fdn from bool flag to threshold value
* Use eps = 1e-6
* Give some nodes a name
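For orientation, the per-head, per-token recurrence that the new fused GGML_OP_DELTA_NET op evaluates, reconstructed from the CUDA and CPU kernels in this diff (my notation, a summary rather than anything stated in the patch), is roughly:

    \hat q_t = \frac{q_t}{\lVert q_t \rVert \sqrt{D}}, \qquad
    \hat k_t = \frac{k_t}{\lVert k_t \rVert}, \qquad
    \beta_t = \sigma(b_t), \qquad
    \gamma_t = e^{\min(g_t,\,50)}

    v^{new}_t = \beta_t\, v_t - \beta_t\, \gamma_t\, S_{t-1} \hat k_t
    o_t       = \gamma_t\, S_{t-1} \hat q_t + (\hat k_t \cdot \hat q_t)\, v^{new}_t
    S_t       = \operatorname{clamp}\!\big(\gamma_t\, S_{t-1} + v^{new}_t\, \hat k_t^{\top},\ \pm 10^{6}\big)

where D = head_dim, S is the per-head D x D recurrent state, and a small eps (1e-6) is added under the square roots of the norms.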
@@ -678,6 +678,7 @@ extern "C" {
        GGML_OP_TRI,
        GGML_OP_FILL,
        GGML_OP_SOLVE_TRI,
        GGML_OP_DELTA_NET,

        GGML_OP_MAP_UNARY,
        GGML_OP_MAP_BINARY,
@@ -2508,6 +2509,15 @@ extern "C" {
            bool lower,
            bool uni);

    GGML_API struct ggml_tensor * ggml_delta_net(
            struct ggml_context * ctx,
            struct ggml_tensor * q,
            struct ggml_tensor * k,
            struct ggml_tensor * v,
            struct ggml_tensor * g,
            struct ggml_tensor * beta,
            struct ggml_tensor * state);

    // custom operators

    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);

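As a rough usage sketch (mine, not part of the patch; the context setup, dimensions and the view-based split of the flat result are illustrative assumptions based on the shape asserts in ggml_delta_net() further down in this diff):

    #include "ggml.h"

    // Hypothetical helper showing how the fused op could be wired into a graph.
    static struct ggml_tensor * build_delta_net_example(void) {
        const int64_t head_dim = 128, n_tokens = 32, n_heads = 16, n_seqs = 1;  // made-up dims

        struct ggml_init_params ip = { /*.mem_size =*/ 256*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * q     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);
        struct ggml_tensor * k     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);
        struct ggml_tensor * v     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_heads, n_seqs);
        struct ggml_tensor * g     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_tokens, 1, n_heads, n_seqs);
        struct ggml_tensor * beta  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, n_tokens, n_heads, n_seqs);
        struct ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, head_dim*n_heads, 1, n_seqs);

        // One flat F32 result: attention output first, updated recurrent state appended after it.
        struct ggml_tensor * y = ggml_delta_net(ctx, q, k, v, g, beta, state);

        const int64_t n_out = head_dim*n_heads*n_tokens*n_seqs;   // S_v*H_v*T*B
        const int64_t n_sta = head_dim*head_dim*n_heads*n_seqs;   // S_v*S_v*H_v*B
        struct ggml_tensor * out       = ggml_view_1d(ctx, y, n_out, 0);
        struct ggml_tensor * new_state = ggml_view_1d(ctx, y, n_sta, n_out*ggml_element_size(y));

        (void) new_state; // a caller would keep this to carry the state to the next ubatch
        return out;
    }

Graph construction and execution (ggml_build_forward_expand() plus the usual backend/compute plumbing) are omitted from the sketch.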
@@ -55,6 +55,7 @@
#include "ggml-cuda/hadamard.cuh"
#include "ggml-cuda/reduce.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/delta-net.cuh"

#include <algorithm>
#include <array>
@@ -3698,6 +3699,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_SOLVE_TRI:
            ggml_cuda_op_solve_tri(ctx, dst);
            break;
        case GGML_OP_DELTA_NET:
            ggml_cuda_op_delta_net(ctx, dst);
            break;
        case GGML_OP_FLASH_ATTN_EXT:
            ggml_cuda_flash_attn_ext(ctx, dst);
            break;
@@ -4557,6 +4561,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                op->src[2]->ne[1] == op->src[0]->ne[1] &&
                op->src[1]->ne[0] == op->src[0]->ne[1] &&
                op->src[3]->ne[0] == op->src[0]->ne[2];
        case GGML_OP_DELTA_NET:
            return true;
        case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;

ggml/src/ggml-cuda/delta-net.cu (new file, 485 lines)
@@ -0,0 +1,485 @@
#include "common.cuh"
#include "delta-net.cuh"
#include <cstdlib>
#include <cstring>

// Delta Net Linear Attention Kernel for Qwen3-Next (HEAD_DIM=128)
// State layout: [S_v, S_v*H_v, 1, n_seqs] (column-major)

__device__ __forceinline__ float sigmoid_f(float x) {
    return 1.0f / (1.0f + expf(-x));
}

template <int block_size>
__device__ __forceinline__ float reduce_sum(float x, float * s) {
    x = warp_reduce_sum(x);
    if constexpr (block_size > WARP_SIZE) {
        //__shared__ float s[block_size/WARP_SIZE];
        int warp_id = threadIdx.x / WARP_SIZE;
        int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s[warp_id] = x;
        }
        __syncthreads();
        x = lane_id < block_size/WARP_SIZE ? s[lane_id] : 0.0f;
        x = warp_reduce_sum(x);
    }
    return x;
}

template <int HEAD_DIM, int block_size>
__global__ void delta_net_recurrent_f32(
        const float * __restrict__ q,        // [HEAD_DIM, n_tokens, n_heads, n_seqs]
        const float * __restrict__ k,        // [HEAD_DIM, n_tokens, n_heads, n_seqs]
        const float * __restrict__ v,        // [HEAD_DIM, n_tokens, n_heads, n_seqs]
        const float * __restrict__ g,        // [n_tokens, 1, n_heads, n_seqs]
        const float * __restrict__ beta_in,  // [1, n_tokens, n_heads, n_seqs]
        const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs]
        float * __restrict__ dst,            // output + new_state concatenated
        const int64_t n_heads,
        const int64_t n_tokens,
        const int64_t n_seqs,
        const int64_t output_offset,         // offset where state starts in output
        const float eps) {
    const int batch_idx = blockIdx.x / n_heads;
    const int head_idx  = blockIdx.x % n_heads;
    const int tid = threadIdx.x;
    const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads
    const int lane_id = tid % WARP_SIZE; // 0-31
    constexpr int NUM_WARPS = block_size/WARP_SIZE;

    // Strides for input tensors (column-major)
    // Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
    const int64_t qkv_stride_token = HEAD_DIM;
    const int64_t qkv_stride_head  = HEAD_DIM * n_tokens;
    const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads;

    // G/Beta: [n_tokens, 1, n_heads, n_seqs] / [1, n_tokens, n_heads, n_seqs]
    const int64_t g_stride_head  = n_tokens;
    const int64_t g_stride_batch = n_tokens * n_heads;

    // State: [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs]
    // For head h: columns h*HEAD_DIM to (h+1)*HEAD_DIM
    // state[row, col] for head h = state[row, h*HEAD_DIM + col]
    // Linear index: row + (h*HEAD_DIM + col) * HEAD_DIM = row + h*HEAD_DIM^2 + col*HEAD_DIM
    const int64_t state_head_offset  = head_idx * HEAD_DIM * HEAD_DIM;
    const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads;

    // Pointers for this batch/head
    const float * q_ptr     = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
    const float * k_ptr     = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
    const float * v_ptr     = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
    const float * g_ptr     = g + batch_idx * g_stride_batch + head_idx * g_stride_head;
    const float * beta_ptr  = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head;
    const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset;

    // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs]
    // For [dim, head, token, batch]: index = dim + head*S_v + token*S_v*H_v + batch*S_v*H_v*n_tokens
    float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM;
    const int64_t out_token_stride = HEAD_DIM * n_heads; // stride between tokens
    float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset;

    // Shared memory for current token's Q, K, V (normalized), and intermediate results
    extern __shared__ float smem[];
    float * sQ         = smem;                  // HEAD_DIM
    float * sK         = sQ + HEAD_DIM;         // HEAD_DIM
    float * sV         = sK + HEAD_DIM;         // HEAD_DIM
    float * sKBeta     = sV + HEAD_DIM;         // HEAD_DIM (plain k for state update)
    float * sVBeta     = sKBeta + HEAD_DIM;     // HEAD_DIM (v * sigmoid(beta))
    float * sOut       = sVBeta + HEAD_DIM;     // HEAD_DIM
    float * sKCumdecay = sOut + HEAD_DIM;       // HEAD_DIM (k * sigmoid(beta) * exp(g))
    float * sVNew      = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)

    const float scale = rsqrtf((float)HEAD_DIM);

    __shared__ float sum_helper[block_size/WARP_SIZE];

    // Copy initial state to output buffer (will be updated in place)
    for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) {
        state_dst[i] = state_src[i];
    }
    __syncthreads();

    // Process each token sequentially
    for (int64_t t = 0; t < n_tokens; t++) {

        float q_sq = 0.0f;
        float k_sq = 0.0f;
        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
            sQ[i] = q_ptr[t * qkv_stride_token + i];
            sK[i] = k_ptr[t * qkv_stride_token + i];
            sV[i] = v_ptr[t * qkv_stride_token + i];
            q_sq += sQ[i] * sQ[i];
            k_sq += sK[i] * sK[i];
        }

        q_sq = reduce_sum<block_size>(q_sq, sum_helper);
        k_sq = reduce_sum<block_size>(k_sq, sum_helper);

        float q_norm = rsqrtf(q_sq + eps);
        float k_norm = rsqrtf(k_sq + eps);

        float beta_val = sigmoid_f(beta_ptr[t]);
        float decay    = expf(fminf(g_ptr[t], 50.0f));

        float sum = 0;
        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
            sQ[i] = sQ[i] * q_norm * scale;
            sK[i] = sK[i] * k_norm;
            sKBeta[i]     = sK[i];
            sVBeta[i]     = sV[i] * beta_val;
            sKCumdecay[i] = sK[i] * beta_val * decay;
            sum += sK[i] * sQ[i];
        }
        float attn_score = reduce_sum<block_size>(sum, sum_helper);
        //__syncthreads();

        for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
            float sum1 = 0.0f;
            float sum2 = 0.0f;
#pragma unroll
            for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
                float sval = state_dst[row_out + col * HEAD_DIM];
                sum1 += sval * sKCumdecay[col];
                sum2 += sval * sQ[col];
            }
            sum1 = warp_reduce_sum(sum1);
            sum2 = warp_reduce_sum(sum2);
            if (lane_id == 0) {
                sVNew[row_out] = sVBeta[row_out] - sum1;
                float v_attn = sVNew[row_out] * attn_score;
                //sOut[row_out] = sum2 * decay + v_attn;
                out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
            }
        }
        __syncthreads();

        for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
#pragma unroll
            for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
                float state_val = state_dst[row + out_dim * HEAD_DIM];
                float safe_decay = decay;
                if (isnan(safe_decay) || isinf(safe_decay)) {
                    safe_decay = 1.0f;
                }
                float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim];
                new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
                state_dst[row + out_dim * HEAD_DIM] = new_state_val;
            }
        }
        if (t < n_tokens - 1) {
            __syncthreads();
        }

    }
}

// Generic kernel that handles any HEAD_DIM at runtime (slower but flexible)
__global__ void delta_net_recurrent_generic_f32(
        const float * __restrict__ q,
        const float * __restrict__ k,
        const float * __restrict__ v,
        const float * __restrict__ g,
        const float * __restrict__ beta_in,
        const float * __restrict__ state_in,
        float * __restrict__ dst,
        const int64_t head_dim,
        const int64_t n_tokens,
        const int64_t n_heads,
        const int64_t n_seqs,
        const int64_t output_offset,
        const float eps) {
    const int batch_idx = blockIdx.x / n_heads;
    const int head_idx  = blockIdx.x % n_heads;
    const int tid = threadIdx.x;

    // Strides (column-major)
    const int64_t qkv_stride_token = head_dim;
    const int64_t qkv_stride_head  = head_dim * n_tokens;
    const int64_t qkv_stride_batch = head_dim * n_tokens * n_heads;

    const int64_t g_stride_head  = n_tokens;
    const int64_t g_stride_batch = n_tokens * n_heads;

    const int64_t state_head_offset  = head_idx * head_dim * head_dim;
    const int64_t state_batch_stride = head_dim * head_dim * n_heads;

    // Pointers
    const float * q_ptr     = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
    const float * k_ptr     = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
    const float * v_ptr     = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
    const float * g_ptr     = g + batch_idx * g_stride_batch + head_idx * g_stride_head;
    const float * beta_ptr  = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head;
    const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset;

    // Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs]
    float * out_base = dst + batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
    const int64_t out_token_stride = head_dim * n_heads;
    float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset;

    // Shared memory for scalars (outside loop)
    __shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score;

    // Dynamic shared memory
    extern __shared__ float smem[];
    float * sQ         = smem;
    float * sK         = sQ + head_dim;
    float * sV         = sK + head_dim;
    float * sKBeta     = sV + head_dim;         // plain k for state update
    float * sVBeta     = sKBeta + head_dim;     // v * sigmoid(beta)
    float * sOut       = sVBeta + head_dim;
    float * sKCumdecay = sOut + head_dim;       // k * sigmoid(beta) * exp(g)
    float * sVPrime    = sKCumdecay + head_dim; // state @ k_cumdecay
    float * sVNew      = sVPrime + head_dim;    // v_beta - v_prime
    float * sNorm      = sVNew + head_dim;

    const float scale = rsqrtf((float)head_dim);

    // Copy initial state to output buffer
    for (int i = tid; i < head_dim * head_dim; i += blockDim.x) {
        int col = i / head_dim;
        int row = i % head_dim;
        state_dst[row + col * head_dim] = state_src[row + col * head_dim];
    }
    __syncthreads();

    // Process each token
    for (int64_t t = 0; t < n_tokens; t++) {
        if (tid < 2) sNorm[tid] = 0.0f;
        __syncthreads();

        // Load Q, K, V
        for (int i = tid; i < head_dim; i += blockDim.x) {
            sQ[i] = q_ptr[t * qkv_stride_token + i];
            sK[i] = k_ptr[t * qkv_stride_token + i];
            sV[i] = v_ptr[t * qkv_stride_token + i];
        }
        __syncthreads();

        // L2 normalize Q and K
        float q_sq = 0.0f, k_sq = 0.0f;
        for (int i = tid; i < head_dim; i += blockDim.x) {
            q_sq += sQ[i] * sQ[i];
            k_sq += sK[i] * sK[i];
        }

#pragma unroll
        for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
            q_sq += __shfl_xor_sync(0xffffffff, q_sq, offset);
            k_sq += __shfl_xor_sync(0xffffffff, k_sq, offset);
        }

        if (tid % WARP_SIZE == 0) {
            atomicAdd(&sNorm[0], q_sq);
            atomicAdd(&sNorm[1], k_sq);
        }
        __syncthreads();

        float q_norm = rsqrtf(sNorm[0] + eps);
        float k_norm = rsqrtf(sNorm[1] + eps);

        for (int i = tid; i < head_dim; i += blockDim.x) {
            sQ[i] *= q_norm * scale;
            sK[i] *= k_norm;
        }
        __syncthreads();

        // Load g and beta, compute decay
        if (tid == 0) {
            shared_g_val    = g_ptr[t];
            shared_beta_val = sigmoid_f(beta_ptr[t]);
            shared_decay    = expf(fminf(shared_g_val, 50.0f));
        }
        __syncthreads();

        float beta_val = shared_beta_val;
        float decay    = shared_decay;

        // Compute k_beta, v_beta, k_cumdecay
        for (int i = tid; i < head_dim; i += blockDim.x) {
            sKBeta[i]     = sK[i];
            sVBeta[i]     = sV[i] * beta_val;
            sKCumdecay[i] = sK[i] * beta_val * decay;
        }
        __syncthreads();

        // Compute v_prime = state @ k_cumdecay
        for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) {
            float v_prime_val = 0.0f;
            for (int col = 0; col < head_dim; col++) {
                // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ k
                v_prime_val += state_dst[row_out + col * head_dim] * sKCumdecay[col];
            }
            sVPrime[row_out] = v_prime_val;
        }
        __syncthreads();

        // Compute v_new = v_beta - v_prime (the value residual)
        for (int i = tid; i < head_dim; i += blockDim.x) {
            sVNew[i] = sVBeta[i] - sVPrime[i];
        }
        __syncthreads();

        // Compute attn_score = dot(k, q) (L2 normalized vectors)
        if (tid == 0) {
            float dot_sum = 0.0f;
            for (int i = 0; i < head_dim; i++) {
                dot_sum += sK[i] * sQ[i];
            }
            shared_attn_score = dot_sum;
        }
        __syncthreads();

        // Compute output: o[t] = attn_inter + v_attn
        // attn_inter = state @ (q * exp(g)) = sum_col(state[row_out, col] * q[col] * exp(g))
        // The decomposed path uses: attn_inter = ggml_mul_mat(state_t, q_g_exp)
        // Since ggml_mul_mat(A,B) = A^T @ B, attn_inter = state_t^T @ q_g_exp = state @ (q * exp(g))
        for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) {
            float attn_inter = 0.0f;

            for (int col = 0; col < head_dim; col++) {
                // Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ q
                float state_val = state_dst[row_out + col * head_dim];
                attn_inter += sQ[col] * decay * state_val;
            }

            // v_attn = v_new * attn_score
            float v_attn = sVNew[row_out] * shared_attn_score;

            // Output = attn_inter + v_attn (correct DeltaNet formula)
            sOut[row_out] = attn_inter + v_attn;
        }
        __syncthreads();

        // Update state: state_new = decay * state + outer(v_new, k)
        // Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx]
        // Uses transposed indexing: state_dst[row + out_dim * head_dim] = state[row][out_dim]
        // Only protect against NaN/Inf - do NOT clamp decay value
        float safe_decay = decay;
        if (isnan(safe_decay) || isinf(safe_decay)) {
            safe_decay = 1.0f;
        }

        for (int out_dim = tid; out_dim < head_dim; out_dim += blockDim.x) {
            for (int row = 0; row < head_dim; row++) {
                float state_val = state_dst[row + out_dim * head_dim];

                // state_new[row][out_dim] = decay * state[row][out_dim] + v_new[row] * k[out_dim]
                // Fix: outer product matches decomposed path: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx]
                float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim];

                // Clamp state to prevent overflow
                new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
                state_dst[row + out_dim * head_dim] = new_state_val;
            }
        }
        __syncthreads();

        // Write output
        for (int i = tid; i < head_dim; i += blockDim.x) {
            out_base[t * out_token_stride + i] = sOut[i];
        }
        __syncthreads();
    }
}

static void delta_net_f32_cuda(
        const float * q,
        const float * k,
        const float * v,
        const float * g,
        const float * beta,
        const float * state_in,
        float * dst,
        const int64_t head_dim,
        const int64_t n_tokens,
        const int64_t n_heads,
        const int64_t n_seqs,
        const float eps,
        const int device_id,
        const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0)
        cudaStream_t stream) {
    GGML_UNUSED(device_id);
    GGML_UNUSED(cc);

    const int64_t output_offset = head_dim * n_tokens * n_heads * n_seqs;

    // One block per (batch, head) pair
    const int num_blocks = n_seqs * n_heads;
    constexpr int threads_per_block = 512; //256;

    // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
    // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
    const size_t smem_size = (9 * head_dim + 6) * sizeof(float);

    // Use templated kernel for common head dimensions, generic for others
    if (head_dim == 64) {
        delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
            q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
    } else if (head_dim == 128) {
        GGML_ASSERT(num_blocks % 8 == 0);
        delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
            q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
    } else {
        delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
            q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
    }

    CUDA_CHECK(cudaGetLastError());

}

void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // q
    const ggml_tensor * src1 = dst->src[1]; // k
    const ggml_tensor * src2 = dst->src[2]; // v
    const ggml_tensor * src3 = dst->src[3]; // g
    const ggml_tensor * src4 = dst->src[4]; // beta
    const ggml_tensor * src5 = dst->src[5]; // state

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t head_dim = src0->ne[0];
    const int64_t n_tokens = src0->ne[1];
    const int64_t n_heads  = src0->ne[2];
    const int64_t n_seqs   = src0->ne[3];

    // Dimension validation
    // Q/K: [head_dim, n_tokens, n_heads, n_seqs]
    GGML_ASSERT(src1->ne[0] == head_dim && src1->ne[1] == n_tokens && src1->ne[2] == n_heads && src1->ne[3] == n_seqs);
    // V: [head_dim, n_tokens, n_heads, n_seqs]
    GGML_ASSERT(src2->ne[0] == head_dim && src2->ne[1] == n_tokens && src2->ne[2] == n_heads && src2->ne[3] == n_seqs);
    // G: [n_tokens, 1, n_heads, n_seqs]
    GGML_ASSERT(src3->ne[0] == n_tokens && src3->ne[1] == 1 && src3->ne[2] == n_heads && src3->ne[3] == n_seqs);
    // Beta: [1, n_tokens, n_heads, n_seqs]
    GGML_ASSERT(src4->ne[0] == 1 && src4->ne[1] == n_tokens && src4->ne[2] == n_heads && src4->ne[3] == n_seqs);
    // State: [head_dim, head_dim*n_heads, 1, n_seqs]
    GGML_ASSERT(src5->ne[0] == head_dim && src5->ne[1] == head_dim * n_heads && src5->ne[2] == 1 && src5->ne[3] == n_seqs);

    // Verify output tensor size
    const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;
    const int64_t state_size  = head_dim * head_dim * n_heads * n_seqs;
    GGML_ASSERT(ggml_nelements(dst) == output_size + state_size);

    const float eps = 1e-6f;

    GGML_ASSERT(head_dim <= 256); // Reasonable limit for shared memory

    // Get device info from ctx (avoids calling CUDA runtime APIs inside dispatch)
    const int device_id = ctx.device;
    const int cc = ggml_cuda_info().devices[device_id].cc;

    delta_net_f32_cuda(
        (const float *)src0->data,
        (const float *)src1->data,
        (const float *)src2->data,
        (const float *)src3->data,
        (const float *)src4->data,
        (const float *)src5->data,
        (float *)dst->data,
        head_dim, n_tokens, n_heads, n_seqs, eps,
        device_id, cc,
        ctx.stream());

}
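A small illustrative check of the dynamic shared memory budget used by the launcher above (standalone C, my sketch; the 48 KB figure is the default per-block dynamic shared memory limit on current NVIDIA GPUs, so no cudaFuncSetAttribute opt-in should be needed for any head_dim the op accepts):

    #include <stdio.h>

    int main(void) {
        // smem = 9*head_dim floats (sQ, sK, sV, sKBeta, sVBeta, sOut, sKCumdecay, sVPrime, sVNew)
        //        + 6 floats (sNorm[2], g_val, beta_val, decay, attn_score)
        for (int head_dim = 64; head_dim <= 256; head_dim *= 2) {
            size_t smem = (9 * (size_t) head_dim + 6) * sizeof(float);
            printf("head_dim %3d -> %zu bytes of dynamic shared memory\n", head_dim, smem);
        }
        return 0; // head_dim = 256 gives 9240 bytes, far below 48 KB
    }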
ggml/src/ggml-cuda/delta-net.cuh (new file, 3 lines)
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml.c (200 lines changed)
@@ -4277,6 +4277,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "TRI",
    "FILL",
    "SOLVE_TRI",
    "DELTA_NET",

    "MAP_UNARY",
    "MAP_BINARY",
@@ -4299,7 +4300,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "FUSED_NORM",
};

static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",
@@ -4395,6 +4396,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "tri(x)",
    "fill(x)",
    "solve_tri(x)",
    "delta_net",

    "f(x)",
    "f(x,y)",
@@ -4417,7 +4419,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "norm(x,y)",
};

static_assert(GGML_OP_COUNT == 100, "GGML_OP_COUNT != 100");
static_assert(GGML_OP_COUNT == 101, "GGML_OP_COUNT != 101");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -9869,6 +9871,59 @@ struct ggml_tensor * ggml_tri(
    return result;
}

struct ggml_tensor * ggml_delta_net(
        struct ggml_context * ctx,
        struct ggml_tensor * q,
        struct ggml_tensor * k,
        struct ggml_tensor * v,
        struct ggml_tensor * g,
        struct ggml_tensor * beta,
        struct ggml_tensor * state) {
    GGML_ASSERT(ggml_is_contiguous(q));
    GGML_ASSERT(ggml_is_contiguous(k));
    GGML_ASSERT(ggml_is_contiguous(v));
    GGML_ASSERT(ggml_is_contiguous(g));
    GGML_ASSERT(ggml_is_contiguous(beta));
    GGML_ASSERT(ggml_is_contiguous(state));

    GGML_ASSERT(q->type == GGML_TYPE_F32);
    GGML_ASSERT(k->type == GGML_TYPE_F32);
    GGML_ASSERT(v->type == GGML_TYPE_F32);
    GGML_ASSERT(g->type == GGML_TYPE_F32);
    GGML_ASSERT(beta->type == GGML_TYPE_F32);
    GGML_ASSERT(state->type == GGML_TYPE_F32);

    const int64_t S_k      = q->ne[0];
    const int64_t n_tokens = q->ne[1];
    const int64_t H_k      = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[2];

    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == n_tokens && k->ne[2] == H_k && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[3] == n_seqs);
    GGML_ASSERT(g->ne[0] == n_tokens && g->ne[1] == 1 && g->ne[2] == H_k && g->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[0] == 1 && beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
    GGML_ASSERT(H_k == H_v);

    const int64_t output_size = S_v * H_v * n_tokens * n_seqs;
    const int64_t state_size  = S_v * S_v * H_v * n_seqs;

    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_size + state_size);

    result->op     = GGML_OP_DELTA_NET;
    result->src[0] = q;
    result->src[1] = k;
    result->src[2] = v;
    result->src[3] = g;
    result->src[4] = beta;
    result->src[5] = state;

    return result;
}

// ggml_fill

static struct ggml_tensor * ggml_fill_impl(
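As a quick worked example of the result sizing in ggml_delta_net() above (all dimensions hypothetical, chosen only for illustration): with S_v = 128, H_v = 16, n_tokens = 512 and n_seqs = 1,

    output_size = 128 * 16 * 512  = 1,048,576 floats
    state_size  = 128 * 128 * 16  =   262,144 floats

so the flat F32 result tensor holds about 1.3 M floats, roughly 5 MB.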
@@ -22476,6 +22531,141 @@ static void ggml_compute_forward_solve_tri(const struct ggml_compute_params * pa
    }
}

// ggml_compute_forward_delta_net

static void ggml_compute_forward_delta_net_f32(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];
    const struct ggml_tensor * src1 = dst->src[1];
    const struct ggml_tensor * src2 = dst->src[2];
    const struct ggml_tensor * src3 = dst->src[3];
    const struct ggml_tensor * src4 = dst->src[4];
    const struct ggml_tensor * src5 = dst->src[5];

    const int64_t head_dim = src0->ne[0];
    const int64_t n_tokens = src0->ne[1];
    const int64_t n_heads  = src0->ne[2];
    const int64_t n_seqs   = src0->ne[3];

    const int64_t output_size = head_dim * n_tokens * n_heads * n_seqs;

    const float * q_data    = (const float *) src0->data;
    const float * k_data    = (const float *) src1->data;
    const float * v_data    = (const float *) src2->data;
    const float * g_data    = (const float *) src3->data;
    const float * beta_data = (const float *) src4->data;
    const float * state_in  = (const float *) src5->data;
    float * out_data  = (float *) dst->data;
    float * state_out = out_data + output_size;

    const int ith = params->ith;
    const int nth = params->nth;

    if (iqk_fused_delta_net(head_dim, n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
                out_data, state_out, ith, nth)) {
        return;
    }

    const int64_t total_heads = n_heads * n_seqs;
    const int64_t heads_per_thread = (total_heads + nth - 1) / nth;
    const int64_t h_start = ith * heads_per_thread;
    const int64_t h_end   = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;

    const float eps   = 1e-12f;
    const float scale = 1.0f / sqrtf((float) head_dim);

    float * v_new_buf = (float *) malloc(head_dim * sizeof(float));
    GGML_ASSERT(v_new_buf);

    for (int64_t h_idx = h_start; h_idx < h_end; ++h_idx) {
        const int64_t batch_idx = h_idx / n_heads;
        const int64_t head_idx  = h_idx % n_heads;

        const int64_t qkv_head_offset   = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
        const int64_t qkv_token_stride  = head_dim;
        const int64_t g_head_offset     = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
        const int64_t state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
        const int64_t out_head_offset   = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
        const int64_t out_token_stride  = head_dim * n_heads;

        for (int64_t i = 0; i < head_dim * head_dim; ++i) {
            state_out[state_head_offset + i] = state_in[state_head_offset + i];
        }

        float * state = state_out + state_head_offset;

        for (int64_t t = 0; t < n_tokens; ++t) {
            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;

            const float g_val    = g_data[g_head_offset + t];
            const float beta_raw = beta_data[g_head_offset + t];

            float q_norm_sq = 0.0f;
            float k_norm_sq = 0.0f;
            for (int64_t i = 0; i < head_dim; ++i) {
                q_norm_sq += q_t[i] * q_t[i];
                k_norm_sq += k_t[i] * k_t[i];
            }
            const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
            const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);

            const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
            const float decay    = expf(fminf(g_val, 50.0f));

            float attn_score = 0.0f;
            for (int64_t i = 0; i < head_dim; ++i) {
                attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
            }

            float * out_t = out_data + out_head_offset + t * out_token_stride;

            for (int64_t row = 0; row < head_dim; ++row) {
                float v_prime = 0.0f;
                float out_val = 0.0f;

                for (int64_t col = 0; col < head_dim; ++col) {
                    const float k_col = k_t[col];
                    const float q_col = q_t[col];
                    const float s = state[row + col * head_dim];

                    v_prime += s * k_col;
                    out_val += s * q_col;
                }

                const float v_new = v_t[row] * beta_val - v_prime * beta_val * decay * k_norm_inv;
                v_new_buf[row] = v_new;
                out_t[row] = out_val * decay * q_norm_inv * scale + v_new * attn_score;
            }

            for (int64_t col = 0; col < head_dim; ++col) {
                const float k_col = k_t[col] * k_norm_inv;
                for (int64_t row = 0; row < head_dim; ++row) {
                    float s = state[row + col * head_dim];
                    s = decay * s + v_new_buf[row] * k_col;
                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
                }
            }
        }
    }

    free(v_new_buf);
}

static void ggml_compute_forward_delta_net(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {
    switch (dst->src[0]->type) {
        case GGML_TYPE_F32:
            ggml_compute_forward_delta_net_f32(params, dst);
            break;
        default:
            GGML_ABORT("fatal error");
    }
}

// ggml_compute_forward_win_part

static void ggml_compute_forward_win_part_f32(
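A rough cost note on the scalar CPU path above (my estimate, not a claim made by the patch): each token makes two full passes over the per-head head_dim x head_dim state, one to accumulate v_prime/out_val and one for the state update, so the total work is approximately

    work ~ n_seqs * n_heads * n_tokens * 2 * head_dim^2 multiply-adds,

about 2 * 128^2 ~ 33 thousand FMAs per head per token at head_dim = 128, which is why the work is split across threads by (sequence, head).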
@@ -24202,6 +24392,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
            {
                ggml_compute_forward_solve_tri(params, tensor);
            } break;
        case GGML_OP_DELTA_NET:
            {
                ggml_compute_forward_delta_net(params, tensor);
            } break;
        case GGML_OP_WIN_PART:
            {
                ggml_compute_forward_win_part(params, tensor);
@@ -25260,6 +25454,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
        case GGML_OP_TRI:
        case GGML_OP_FILL:
        case GGML_OP_SOLVE_TRI:
        case GGML_OP_DELTA_NET:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
            }
@@ -25990,6 +26185,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_FUSED_UP_GATE:
        case GGML_OP_OUT_PROD:
        case GGML_OP_SOLVE_TRI:
        case GGML_OP_DELTA_NET:
            {
                n_tasks = n_threads;
            } break;

@@ -1383,6 +1383,155 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
}
#endif

namespace {
template <int head_dim>
void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
    const int total_heads = n_heads * n_seqs;
    const int heads_per_thread = (total_heads + nth - 1) / nth;
    const int h_start = ith * heads_per_thread;
    const int h_end   = (h_start + heads_per_thread < total_heads) ? h_start + heads_per_thread : total_heads;

#ifdef __AVX2__
    static_assert(head_dim % 8 == 0);
#endif

    const float eps   = 1e-6f;
    const float scale = 1.0f / sqrtf((float) head_dim);

    float v_new_buf[head_dim];
    float v_prime[head_dim], out_val[head_dim];

    for (int h_idx = h_start; h_idx < h_end; ++h_idx) {
        const int batch_idx = h_idx / n_heads;
        const int head_idx  = h_idx % n_heads;

        const int qkv_head_offset   = batch_idx * (head_dim * n_tokens * n_heads) + head_idx * (head_dim * n_tokens);
        const int qkv_token_stride  = head_dim;
        const int g_head_offset     = batch_idx * (n_tokens * n_heads) + head_idx * n_tokens;
        const int state_head_offset = batch_idx * (head_dim * head_dim * n_heads) + head_idx * (head_dim * head_dim);
        const int out_head_offset   = batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
        const int out_token_stride  = head_dim * n_heads;

        for (int i = 0; i < head_dim * head_dim; ++i) {
            state_out[state_head_offset + i] = state_in[state_head_offset + i];
        }

        float * state = state_out + state_head_offset;

        for (int t = 0; t < n_tokens; ++t) {
            const float * q_t = q_data + qkv_head_offset + t * qkv_token_stride;
            const float * k_t = k_data + qkv_head_offset + t * qkv_token_stride;
            const float * v_t = v_data + qkv_head_offset + t * qkv_token_stride;

            const float g_val    = g_data[g_head_offset + t];
            const float beta_raw = beta_data[g_head_offset + t];

            float q_norm_sq = 0.0f;
            float k_norm_sq = 0.0f;
            float kq_sum    = 0.0f;
#ifdef __AVX2__
            auto vqsum  = _mm256_setzero_ps();
            auto vksum  = _mm256_setzero_ps();
            auto vqksum = _mm256_setzero_ps();
            for (int i = 0; i < head_dim; i += 8) {
                auto vq = _mm256_loadu_ps(q_t + i);
                auto vk = _mm256_loadu_ps(k_t + i);
                vqsum  = _mm256_fmadd_ps(vq, vq, vqsum);
                vksum  = _mm256_fmadd_ps(vk, vk, vksum);
                vqksum = _mm256_fmadd_ps(vk, vq, vqksum);
            }
            q_norm_sq = hsum_float_8(vqsum);
            k_norm_sq = hsum_float_8(vksum);
            kq_sum    = hsum_float_8(vqksum);
#else
            for (int i = 0; i < head_dim; ++i) {
                q_norm_sq += q_t[i] * q_t[i];
                k_norm_sq += k_t[i] * k_t[i];
                kq_sum    += k_t[i] * q_t[i];
            }
#endif
            const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
            const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);

            const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
            const float decay    = expf(fminf(g_val, 50.0f));

            float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale;

            //float attn_score = 0.0f;
            //for (int i = 0; i < head_dim; ++i) {
            //    attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
            //}

            float * out_t = out_data + out_head_offset + t * out_token_stride;

            std::memset(v_prime, 0, head_dim*sizeof(float));
            std::memset(out_val, 0, head_dim*sizeof(float));
            for (int col = 0; col < head_dim; ++col) {
                const float k_col = k_t[col];
                const float q_col = q_t[col];
                for (int row = 0; row < head_dim; ++row) {
                    const float s = state[row + col * head_dim];
                    v_prime[row] += s * k_col;
                    out_val[row] += s * q_col;
                }
            }
            for (int row = 0; row < head_dim; ++row) {
                const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv;
                v_new_buf[row] = v_new;
                out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score;
            }

#ifdef __AVX2__
            auto vd   = _mm256_set1_ps(decay);
            auto vmin = _mm256_set1_ps(-1e6f);
            auto vmax = _mm256_set1_ps( 1e6f);
            for (int col = 0; col < head_dim; ++col) {
                auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv);
                for (int row = 0; row < head_dim; row += 8) {
                    auto vs = _mm256_loadu_ps(state + col * head_dim + row);
                    auto vn = _mm256_loadu_ps(v_new_buf + row);
                    vs = _mm256_fmadd_ps(vn, vk, _mm256_mul_ps(vs, vd));
                    auto mask_l = _mm256_cmp_ps(vs, vmin, _CMP_LT_OQ);
                    auto mask_u = _mm256_cmp_ps(vs, vmax, _CMP_GT_OQ);
                    vs = _mm256_or_ps(_mm256_and_ps(mask_l, vmin), _mm256_andnot_ps(mask_l, vs));
                    vs = _mm256_or_ps(_mm256_and_ps(mask_u, vmax), _mm256_andnot_ps(mask_u, vs));
                    _mm256_storeu_ps(state + col * head_dim + row, vs);
                }
            }
#else
            for (int col = 0; col < head_dim; ++col) {
                const float k_col = k_t[col] * k_norm_inv;
                for (int row = 0; row < head_dim; ++row) {
                    float s = state[row + col * head_dim];
                    s = decay * s + v_new_buf[row] * k_col;
                    state[row + col * head_dim] = fminf(fmaxf(s, -1e6f), 1e6f);
                }
            }
#endif
        }
    }
}
}

bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
        const float * state_in, float * out_data, float * state_out, int ith, int nth) {
    if (head_dim != 64 && head_dim != 128) {
        return false;
    }
    if (head_dim == 64) {
        iqk_fused_delta_net_impl<64>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
                out_data, state_out, ith, nth);
    } else {
        iqk_fused_delta_net_impl<128>(n_heads, n_tokens, n_seqs, q_data, k_data, v_data, g_data, beta_data, state_in,
                out_data, state_out, ith, nth);
    }
    return true;
}

#else // IQK_IMPLEMENT

#include "ggml-impl.h"
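Side note (mine, not part of the patch): the compare/and/or sequence in the AVX2 state clamp in the hunk above is functionally a min/max clamp; an equivalent and slightly shorter formulation for finite inputs would be:

    // illustrative alternative to the mask-based clamp; same result for finite values
    vs = _mm256_max_ps(vmin, _mm256_min_ps(vs, vmax));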
@@ -1416,4 +1565,11 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
    return false;
}

bool iqk_fused_delta_net(int, int, int, int,
        const float *, const float *, const float *, const float *, const float *,
        const float *, float *, float *, int, int) {
    return false;
}


#endif

@@ -73,6 +73,10 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
IQK_API void iqk_topk_moe(int n_experts, int n_experts_used, int nrows, const float * logits,
        float * weights, int32_t * ids, int ith, int nth);

IQK_API bool iqk_fused_delta_net(int head_dim, int n_heads, int n_tokens, int n_seqs,
        const float * q_data, const float * k_data, const float * v_data, const float * g_data, const float * beta_data,
        const float * state_in, float * out_data, float * state_out, int ith, int nth);

#ifdef __cplusplus
}
#endif
