mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 17:14:17 +00:00
@@ -359,7 +359,7 @@ struct gpt_params {
|
||||
bool split_mode_graph_scheduling = false; // if true, force split mode graph scheduling
|
||||
//bool split_mode_f16 = true; // if true, intermediate results will be cast to f16 before copying to other GPUs to perform reduce ops
|
||||
bool scheduler_async = false; // if true, in split mode graph the scheduler will use multiple threads to evaluate the graph
|
||||
int fused_delta_net = 65536; // use fused delta-net if number of tokens in the batch is less than this value
|
||||
int fused_delta_net = 0; // use fused delta-net if number of tokens in the batch is less than this value
|
||||
bool has_mtp = false; // enable MTP if supported by the model
|
||||
|
||||
std::string cache_type_k = "f16"; // KV cache data type for the K
|
||||
|
||||
@@ -271,7 +271,7 @@ struct cmd_params {
|
||||
bool muge = false;
|
||||
bool rcache = false;
|
||||
bool sas = false;
|
||||
int fdn = 65536; // fdn = fused delta net
|
||||
int fdn = 0; // fdn = fused delta net
|
||||
bool print_overrides = false;
|
||||
output_formats output_format;
|
||||
output_formats output_format_stderr;
|
||||
@@ -317,7 +317,7 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* muge */ false,
|
||||
/* rcache */ false,
|
||||
/* sas */ false,
|
||||
/* fdn */ 65536,
|
||||
/* fdn */ 0,
|
||||
/* print_overrides */ false,
|
||||
/* output_format */ MARKDOWN,
|
||||
/* output_format_stderr */ NONE,
|
||||
|
||||
@@ -41,12 +41,12 @@ __global__ void delta_net_recurrent_f32(
|
||||
const int64_t n_seqs,
|
||||
const int64_t output_offset, // offset where state starts in output
|
||||
const float eps) {
|
||||
constexpr int warps_per_head = HEAD_DIM/WARP_SIZE;
|
||||
const int batch_idx = blockIdx.x / (warps_per_head*n_heads);
|
||||
const int sub_head_idx = blockIdx.x % (warps_per_head*n_heads);
|
||||
const int head_idx = sub_head_idx / warps_per_head;
|
||||
const int sub_idx = sub_head_idx % warps_per_head;
|
||||
const int batch_idx = blockIdx.x / n_heads;
|
||||
const int head_idx = blockIdx.x % n_heads;
|
||||
const int tid = threadIdx.x;
|
||||
const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads
|
||||
const int lane_id = tid % WARP_SIZE; // 0-31
|
||||
constexpr int NUM_WARPS = block_size/WARP_SIZE;
|
||||
|
||||
// Strides for input tensors (column-major)
|
||||
// Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
|
||||
@@ -83,34 +83,32 @@ __global__ void delta_net_recurrent_f32(
|
||||
extern __shared__ float smem[];
|
||||
float * sQ = smem; // HEAD_DIM
|
||||
float * sK = sQ + HEAD_DIM; // HEAD_DIM
|
||||
float * sV = sK + HEAD_DIM; // HEAD_DIM
|
||||
float * sVNew = sV + HEAD_DIM; // HEAD_DIM
|
||||
|
||||
const float scale = rsqrtf((float)HEAD_DIM);
|
||||
|
||||
__shared__ float sum_helper[block_size/WARP_SIZE];
|
||||
|
||||
constexpr int num_warps = block_size/WARP_SIZE;
|
||||
const int row = tid % WARP_SIZE;
|
||||
const int col_idx_0 = tid / WARP_SIZE;
|
||||
const int row_out = row + sub_idx * WARP_SIZE;
|
||||
|
||||
// Keep the state in registers, copy the final state to its destination at the end
|
||||
float state_local[HEAD_DIM/num_warps];
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
state_local[i] = state_src[col*HEAD_DIM + row_out];
|
||||
// Copy initial state to output buffer (will be updated in place)
|
||||
for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += block_size) {
|
||||
state_dst[i] = state_src[i];
|
||||
}
|
||||
|
||||
constexpr int WARP_SIZE_S = WARP_SIZE + 1;
|
||||
constexpr int num_stored_rows = block_size/WARP_SIZE;
|
||||
__shared__ float all_sum[2*WARP_SIZE_S*num_stored_rows];
|
||||
constexpr int HEAD_DIM_S = HEAD_DIM + 1;
|
||||
constexpr int num_stored_rows = block_size >= HEAD_DIM && block_size % HEAD_DIM == 0 ? block_size/HEAD_DIM : NUM_WARPS;
|
||||
__shared__ float all_sum[2*HEAD_DIM_S*num_stored_rows];
|
||||
auto all_sum1 = all_sum;
|
||||
auto all_sum2 = all_sum1 + WARP_SIZE_S*num_stored_rows;
|
||||
auto all_sum2 = all_sum1 + HEAD_DIM_S*num_stored_rows;
|
||||
|
||||
// Process each token sequentially
|
||||
for (int64_t t = 0; t < n_tokens; t++) {
|
||||
|
||||
float sum_kq = 0.0f;
|
||||
for (int i = tid; i < HEAD_DIM; i += block_size) {
|
||||
sQ[i] = q_ptr[t * qkv_stride_token + i] * scale;
|
||||
sK[i] = k_ptr[t * qkv_stride_token + i];
|
||||
sV[i] = v_ptr[t * qkv_stride_token + i];
|
||||
sum_kq += sK[i] * sQ[i];
|
||||
}
|
||||
|
||||
@@ -119,44 +117,281 @@ __global__ void delta_net_recurrent_f32(
|
||||
float beta_val = sigmoid_f(beta_ptr[t]);
|
||||
float decay = expf(fminf(g_ptr[t], 50.0f));
|
||||
|
||||
float sum1 = 0, sum2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
sum1 += state_local[i] * sK[col];
|
||||
sum2 += state_local[i] * sQ[col];
|
||||
}
|
||||
all_sum1[col_idx_0*WARP_SIZE_S + row] = sum1;
|
||||
all_sum2[col_idx_0*WARP_SIZE_S + row] = sum2;
|
||||
if constexpr (block_size >= HEAD_DIM && block_size % HEAD_DIM == 0) {
|
||||
int idx = tid / HEAD_DIM;
|
||||
int row_out = tid % HEAD_DIM;
|
||||
float sum1 = 0, sum2 = 0;
|
||||
#pragma unroll
|
||||
for (int col = idx; col < HEAD_DIM; col += block_size/HEAD_DIM) {
|
||||
float sval = state_dst[row_out + col * HEAD_DIM];
|
||||
sum1 += sval * sK[col];
|
||||
sum2 += sval * sQ[col];
|
||||
}
|
||||
all_sum1[idx*HEAD_DIM_S + row_out] = sum1;
|
||||
all_sum2[idx*HEAD_DIM_S + row_out] = sum2;
|
||||
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
|
||||
sum1 = sum2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < block_size/WARP_SIZE; ++i) {
|
||||
sum1 += all_sum1[i*WARP_SIZE_S + row];
|
||||
sum2 += all_sum2[i*WARP_SIZE_S + row];
|
||||
}
|
||||
// To be honest, I don't understand why we need this sync. But without it I observe results varying from run to run
|
||||
__syncthreads();
|
||||
if (idx == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 1; i < block_size/HEAD_DIM; ++i) {
|
||||
sum1 += all_sum1[i*HEAD_DIM_S + row_out];
|
||||
sum2 += all_sum2[i*HEAD_DIM_S + row_out];
|
||||
}
|
||||
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
|
||||
float v_attn = sVNew[row_out] * attn_score;
|
||||
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
|
||||
}
|
||||
__syncthreads();
|
||||
} else {
|
||||
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
|
||||
float sum1 = 0.0f;
|
||||
float sum2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
|
||||
float sval = state_dst[row_out + col * HEAD_DIM];
|
||||
sum1 += sval * sK[col];
|
||||
sum2 += sval * sQ[col];
|
||||
}
|
||||
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
|
||||
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
|
||||
if (col_idx_0 == 0) {
|
||||
out_base[t * out_token_stride + row_out] = sum2 * decay + sv_new * attn_score;
|
||||
for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
|
||||
float sum1 = all_sum1[row_out];
|
||||
float sum2 = all_sum2[row_out];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < NUM_WARPS; ++i) {
|
||||
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
|
||||
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
|
||||
}
|
||||
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
|
||||
float v_attn = sVNew[row_out] * attn_score;
|
||||
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
float new_state_val = decay * state_local[i] + sv_new * sK[col];
|
||||
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
|
||||
state_local[i] = new_state_val;
|
||||
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
|
||||
float k_col = sK[out_dim];
|
||||
#pragma unroll
|
||||
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
|
||||
float state_val = state_dst[row + out_dim * HEAD_DIM];
|
||||
float new_state_val = decay * state_val + sVNew[row] * k_col; //sK[out_dim];
|
||||
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
|
||||
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
// Copy the final state to its destination
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
state_dst[col*HEAD_DIM + row_out] = state_local[i];
|
||||
}
|
||||
|
||||
// Generic kernel that handles any HEAD_DIM at runtime (slower but flexible)
|
||||
__global__ void delta_net_recurrent_generic_f32(
|
||||
const float * __restrict__ q,
|
||||
const float * __restrict__ k,
|
||||
const float * __restrict__ v,
|
||||
const float * __restrict__ g,
|
||||
const float * __restrict__ beta_in,
|
||||
const float * __restrict__ state_in,
|
||||
float * __restrict__ dst,
|
||||
const int64_t head_dim,
|
||||
const int64_t n_tokens,
|
||||
const int64_t n_heads,
|
||||
const int64_t n_seqs,
|
||||
const int64_t output_offset,
|
||||
const float eps) {
|
||||
const int batch_idx = blockIdx.x / n_heads;
|
||||
const int head_idx = blockIdx.x % n_heads;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Strides (column-major)
|
||||
const int64_t qkv_stride_token = head_dim;
|
||||
const int64_t qkv_stride_head = head_dim * n_tokens;
|
||||
const int64_t qkv_stride_batch = head_dim * n_tokens * n_heads;
|
||||
|
||||
const int64_t g_stride_head = n_tokens;
|
||||
const int64_t g_stride_batch = n_tokens * n_heads;
|
||||
|
||||
const int64_t state_head_offset = head_idx * head_dim * head_dim;
|
||||
const int64_t state_batch_stride = head_dim * head_dim * n_heads;
|
||||
|
||||
// Pointers
|
||||
const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
|
||||
const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
|
||||
const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head;
|
||||
const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head;
|
||||
const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head;
|
||||
const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset;
|
||||
|
||||
// Output layout: [head_v_dim, num_v_heads, n_seq_tokens, n_seqs]
|
||||
float * out_base = dst + batch_idx * (head_dim * n_heads * n_tokens) + head_idx * head_dim;
|
||||
const int64_t out_token_stride = head_dim * n_heads;
|
||||
float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset;
|
||||
|
||||
// Shared memory for scalars (outside loop)
|
||||
__shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score;
|
||||
|
||||
// Dynamic shared memory
|
||||
extern __shared__ float smem[];
|
||||
float * sQ = smem;
|
||||
float * sK = sQ + head_dim;
|
||||
float * sV = sK + head_dim;
|
||||
float * sKBeta = sV + head_dim; // plain k for state update
|
||||
float * sVBeta = sKBeta + head_dim; // v * sigmoid(beta)
|
||||
float * sOut = sVBeta + head_dim;
|
||||
float * sKCumdecay = sOut + head_dim; // k * sigmoid(beta) * exp(g)
|
||||
float * sVPrime = sKCumdecay + head_dim; // state @ k_cumdecay
|
||||
float * sVNew = sVPrime + head_dim; // v_beta - v_prime
|
||||
float * sNorm = sVNew + head_dim;
|
||||
|
||||
const float scale = rsqrtf((float)head_dim);
|
||||
|
||||
// Copy initial state to output buffer
|
||||
for (int i = tid; i < head_dim * head_dim; i += blockDim.x) {
|
||||
int col = i / head_dim;
|
||||
int row = i % head_dim;
|
||||
state_dst[row + col * head_dim] = state_src[row + col * head_dim];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Process each token
|
||||
for (int64_t t = 0; t < n_tokens; t++) {
|
||||
if (tid < 2) sNorm[tid] = 0.0f;
|
||||
__syncthreads();
|
||||
|
||||
// Load Q, K, V
|
||||
for (int i = tid; i < head_dim; i += blockDim.x) {
|
||||
sQ[i] = q_ptr[t * qkv_stride_token + i];
|
||||
sK[i] = k_ptr[t * qkv_stride_token + i];
|
||||
sV[i] = v_ptr[t * qkv_stride_token + i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// L2 normalize Q and K
|
||||
float q_sq = 0.0f, k_sq = 0.0f;
|
||||
for (int i = tid; i < head_dim; i += blockDim.x) {
|
||||
q_sq += sQ[i] * sQ[i];
|
||||
k_sq += sK[i] * sK[i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
|
||||
q_sq += __shfl_xor_sync(0xffffffff, q_sq, offset);
|
||||
k_sq += __shfl_xor_sync(0xffffffff, k_sq, offset);
|
||||
}
|
||||
|
||||
if (tid % WARP_SIZE == 0) {
|
||||
atomicAdd(&sNorm[0], q_sq);
|
||||
atomicAdd(&sNorm[1], k_sq);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float q_norm = rsqrtf(sNorm[0] + eps);
|
||||
float k_norm = rsqrtf(sNorm[1] + eps);
|
||||
|
||||
for (int i = tid; i < head_dim; i += blockDim.x) {
|
||||
sQ[i] *= q_norm * scale;
|
||||
sK[i] *= k_norm;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load g and beta, compute decay
|
||||
if (tid == 0) {
|
||||
shared_g_val = g_ptr[t];
|
||||
shared_beta_val = sigmoid_f(beta_ptr[t]);
|
||||
shared_decay = expf(fminf(shared_g_val, 50.0f));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float beta_val = shared_beta_val;
|
||||
float decay = shared_decay;
|
||||
|
||||
// Compute k_beta, v_beta, k_cumdecay
|
||||
for (int i = tid; i < head_dim; i += blockDim.x) {
|
||||
sKBeta[i] = sK[i];
|
||||
sVBeta[i] = sV[i] * beta_val;
|
||||
sKCumdecay[i] = sK[i] * beta_val * decay;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute v_prime = state @ k_cumdecay
|
||||
for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) {
|
||||
float v_prime_val = 0.0f;
|
||||
for (int col = 0; col < head_dim; col++) {
|
||||
// Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ k
|
||||
v_prime_val += state_dst[row_out + col * head_dim] * sKCumdecay[col];
|
||||
}
|
||||
sVPrime[row_out] = v_prime_val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute v_new = v_beta - v_prime (the value residual)
|
||||
for (int i = tid; i < head_dim; i += blockDim.x) {
|
||||
sVNew[i] = sVBeta[i] - sVPrime[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute attn_score = dot(k, q) (L2 normalized vectors)
|
||||
if (tid == 0) {
|
||||
float dot_sum = 0.0f;
|
||||
for (int i = 0; i < head_dim; i++) {
|
||||
dot_sum += sK[i] * sQ[i];
|
||||
}
|
||||
shared_attn_score = dot_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute output: o[t] = attn_inter + v_attn
|
||||
// attn_inter = state @ (q * exp(g)) = sum_col(state[row_out, col] * q[col] * exp(g))
|
||||
// The decomposed path uses: attn_inter = ggml_mul_mat(state_t, q_g_exp)
|
||||
// Since ggml_mul_mat(A,B) = A^T @ B, attn_inter = state_t^T @ q_g_exp = state @ (q * exp(g))
|
||||
for (int row_out = tid; row_out < head_dim; row_out += blockDim.x) {
|
||||
float attn_inter = 0.0f;
|
||||
|
||||
for (int col = 0; col < head_dim; col++) {
|
||||
// Access state[row_out, col] = state_dst[row_out + col * head_dim] for state @ q
|
||||
float state_val = state_dst[row_out + col * head_dim];
|
||||
attn_inter += sQ[col] * decay * state_val;
|
||||
}
|
||||
|
||||
// v_attn = v_new * attn_score
|
||||
float v_attn = sVNew[row_out] * shared_attn_score;
|
||||
|
||||
// Output = attn_inter + v_attn (correct DeltaNet formula)
|
||||
sOut[row_out] = attn_inter + v_attn;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Update state: state_new = decay * state + outer(v_new, k)
|
||||
// Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx]
|
||||
// Uses transposed indexing: state_dst[row + out_dim * head_dim] = state[row][out_dim]
|
||||
// Only protect against NaN/Inf - do NOT clamp decay value
|
||||
float safe_decay = decay;
|
||||
if (isnan(safe_decay) || isinf(safe_decay)) {
|
||||
safe_decay = 1.0f;
|
||||
}
|
||||
|
||||
for (int out_dim = tid; out_dim < head_dim; out_dim += blockDim.x) {
|
||||
for (int row = 0; row < head_dim; row++) {
|
||||
float state_val = state_dst[row + out_dim * head_dim];
|
||||
|
||||
// state_new[row][out_dim] = decay * state[row][out_dim] + v_new[row] * k[out_dim]
|
||||
// Fix: outer product matches decomposed path: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx]
|
||||
float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim];
|
||||
|
||||
// Clamp state to prevent overflow
|
||||
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
|
||||
state_dst[row + out_dim * head_dim] = new_state_val;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write output
|
||||
for (int i = tid; i < head_dim; i += blockDim.x) {
|
||||
out_base[t * out_token_stride + i] = sOut[i];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,32 +416,24 @@ static void delta_net_f32_cuda(
|
||||
|
||||
const int64_t output_offset = head_dim * n_tokens * n_heads * n_seqs;
|
||||
|
||||
if (head_dim != 64 && head_dim != 128) {
|
||||
GGML_ABORT("Unsupported delta net head size");
|
||||
}
|
||||
// One block per (batch, head) pair
|
||||
const int num_blocks = n_seqs * n_heads;
|
||||
constexpr int threads_per_block = 512; //256;
|
||||
|
||||
GGML_ASSERT(head_dim % WARP_SIZE == 0);
|
||||
const int num_blocks = n_seqs * n_heads * (head_dim/WARP_SIZE);
|
||||
const size_t smem_size = 2 * head_dim * sizeof(float);
|
||||
const size_t smem_size = 4 * head_dim * sizeof(float);
|
||||
|
||||
if (n_tokens <= 8) {
|
||||
constexpr int threads_per_block = 256;
|
||||
if (head_dim == 64) {
|
||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
// Use templated kernel for common head dimensions, generic for others
|
||||
if (head_dim == 64) {
|
||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
} else if (head_dim == 128) {
|
||||
GGML_ASSERT(num_blocks % 8 == 0);
|
||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
} else {
|
||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
}
|
||||
} else {
|
||||
constexpr int threads_per_block = 128;
|
||||
if (head_dim == 64) {
|
||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
} else {
|
||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
}
|
||||
GGML_ASSERT("Unsupported delta net head size");
|
||||
delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
#define DELTA_CHUNK_SIZE 64
|
||||
#define QWEN3NEXT_CHUNK_SIZE 64
|
||||
|
||||
delta_net::delta_net(llama_context & _lctx, const llama_batch & _batch) : lctx(_lctx), batch(_batch) {
|
||||
auto & model = lctx.model;
|
||||
@@ -111,7 +111,7 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_delta_net_chunking(ggml
|
||||
cb(g, "g_in", il);
|
||||
cb(state,"state_in", il);
|
||||
|
||||
const int64_t chunk_size = DELTA_CHUNK_SIZE;
|
||||
const int64_t chunk_size = QWEN3NEXT_CHUNK_SIZE;
|
||||
const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
|
||||
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
|
||||
|
||||
@@ -296,8 +296,8 @@ std::pair<ggml_tensor *, ggml_tensor *> delta_net::build_delta_net_chunking(ggml
|
||||
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
|
||||
S_v, n_tokens, H_v, n_seqs,
|
||||
ggml_row_size(core_attn_out->type, S_v),
|
||||
ggml_row_size(core_attn_out->type, S_v * DELTA_CHUNK_SIZE * n_chunks),
|
||||
ggml_row_size(core_attn_out->type, S_v * DELTA_CHUNK_SIZE * n_chunks * H_v), 0);
|
||||
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks),
|
||||
ggml_row_size(core_attn_out->type, S_v * QWEN3NEXT_CHUNK_SIZE * n_chunks * H_v), 0);
|
||||
cb(output_tokens, "output_tokens", il);
|
||||
|
||||
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
|
||||
@@ -572,20 +572,19 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
|
||||
beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_tok, 1);
|
||||
alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_tok, 1);
|
||||
cb(beta, "beta", il);
|
||||
cb(alpha, "alpha", il);
|
||||
} else {
|
||||
beta = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_beta, cur);
|
||||
alpha = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_alpha, cur);
|
||||
ggml_build_forward_expand(gf, beta);
|
||||
ggml_build_forward_expand(gf, alpha);
|
||||
cb(beta, "beta", il);
|
||||
cb(alpha, "alpha", il);
|
||||
beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_tok, 1);
|
||||
cb(beta, "beta_reshaped", il);
|
||||
alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
|
||||
cb(alpha, "alpha_reshaped", il);
|
||||
alpha = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_alpha, cur);
|
||||
cb(alpha, "alpha", il);
|
||||
// Why? Don't think this ggml_cont_3d is needed, but lets leave it in for now just in case.
|
||||
alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
|
||||
cb(alpha, "alpha_cont", il);
|
||||
}
|
||||
cb(beta, "beta", il);
|
||||
cb(alpha, "alpha", il);
|
||||
ggml_build_forward_expand(gf, beta);
|
||||
ggml_build_forward_expand(gf, alpha);
|
||||
|
||||
@@ -607,13 +606,18 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
state_all = ggml_view_2d(ctx0, state_storage, state_dim, qnext_state_slots, state_row_size, 0);
|
||||
|
||||
ggml_tensor * state_dst = ggml_view_2d(ctx0, state_all, state_dim, 1, state_row_size, state_seq_id_local * state_row_size);
|
||||
ggml_tensor * state_f32 = state_dst;
|
||||
if (state_f32->type != GGML_TYPE_F32) {
|
||||
state_f32 = ggml_cast(ctx0, state_f32, GGML_TYPE_F32);
|
||||
}
|
||||
if (reset_state_local) {
|
||||
state_dst = ggml_scale(ctx0, state_dst, 0.0f);
|
||||
cb(state_dst, "state_reset", il);
|
||||
state_f32 = ggml_scale(ctx0, state_f32, 0.0f);
|
||||
cb(state_f32, "state_reset", il);
|
||||
}
|
||||
|
||||
ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_dst, conv_state_dim, 1, state_dst->nb[1], 0);
|
||||
ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_dst, ssm_state_dim, 1, state_dst->nb[1], conv_state_dim * ggml_element_size(state_dst));
|
||||
ggml_tensor * conv_state_flat = ggml_view_2d(ctx0, state_f32, conv_state_dim, 1, state_f32->nb[1], 0);
|
||||
ggml_tensor * ssm_state_flat = ggml_view_2d(ctx0, state_f32, ssm_state_dim, 1, state_f32->nb[1],
|
||||
conv_state_dim * ggml_element_size(state_f32));
|
||||
|
||||
ggml_tensor * conv_states = ggml_reshape_3d(ctx0, conv_state_flat, hparams.ssm_d_conv - 1, conv_dim, 1);
|
||||
ggml_tensor * state = ggml_reshape_4d(ctx0, ssm_state_flat, head_v_dim, head_v_dim, num_v_heads, 1);
|
||||
@@ -624,6 +628,8 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
ggml_tensor * conv_output_raw = ggml_ssm_conv(ctx0, conv_states, qkv_mixed, model.layers[il].ssm_conv1d, inp_s_seq_qnext);
|
||||
cb(conv_output_raw, "conv_output_raw", il);
|
||||
|
||||
//ggml_tensor * conv_output = ggml_view_2d(ctx0, conv_output_raw, conv_dim, n_tok, conv_dim * ggml_element_size(conv_output_raw), 0);
|
||||
//ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output);
|
||||
ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_raw);
|
||||
cb(conv_output_silu, "conv_output_silu", il);
|
||||
|
||||
@@ -633,24 +639,27 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
|
||||
// Extract the convolved Q, K, V from conv_output
|
||||
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_tok, 1,
|
||||
ggml_row_size(conv_output_silu->type, head_k_dim), nb1_qkv, nb1_qkv * n_tok, 0);
|
||||
ggml_row_size(conv_output_silu->type, head_k_dim),
|
||||
nb1_qkv, nb1_qkv * n_tok, 0);
|
||||
|
||||
ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_output_silu, head_k_dim, num_k_heads, n_tok, 1,
|
||||
ggml_row_size(conv_output_silu->type, head_k_dim), nb1_qkv, nb1_qkv * n_tok,
|
||||
ggml_row_size(conv_output_silu->type, head_k_dim),
|
||||
nb1_qkv, nb1_qkv * n_tok,
|
||||
head_k_dim * num_k_heads * ggml_element_size(conv_output_silu));
|
||||
|
||||
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_output_silu, head_v_dim, num_v_heads, n_tok, 1,
|
||||
ggml_row_size(conv_output_silu->type, head_v_dim), nb1_qkv, nb1_qkv * n_tok,
|
||||
ggml_row_size(conv_output_silu->type, head_v_dim),
|
||||
nb1_qkv, nb1_qkv * n_tok,
|
||||
ggml_row_size(conv_output_silu->type, 2 * head_k_dim * num_k_heads));
|
||||
|
||||
cb(q_conv, "q_conv", il);
|
||||
cb(k_conv, "k_conv", il);
|
||||
cb(v_conv, "v_conv", il);
|
||||
|
||||
q_conv = ggml_l2_norm(ctx0, q_conv, hparams.f_norm_rms_eps);
|
||||
k_conv = ggml_l2_norm(ctx0, k_conv, hparams.f_norm_rms_eps);
|
||||
cb(q_conv, "q_conv_normed", il);
|
||||
cb(k_conv, "k_conv_normed", il);
|
||||
const float eps_norm = hparams.f_norm_rms_eps;
|
||||
|
||||
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
|
||||
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
|
||||
|
||||
if (num_k_heads != num_v_heads) {
|
||||
GGML_ASSERT(num_v_heads % num_k_heads == 0);
|
||||
@@ -700,6 +709,9 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
ggml_tensor * new_state_flat = ggml_concat(ctx0, new_conv_flat, new_ssm_flat, 0);
|
||||
|
||||
ggml_tensor * state_update = new_state_flat;
|
||||
if (state_dst->type != GGML_TYPE_F32) {
|
||||
state_update = ggml_cast(ctx0, state_update, state_dst->type);
|
||||
}
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, state_update, state_dst));
|
||||
|
||||
ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_tok);
|
||||
@@ -716,8 +728,7 @@ ggml_tensor * delta_net::build_layer_attn_linear_core(ggml_context * ctx0, ggml_
|
||||
ggml_tensor * out = llm_build_context::llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, final_output);
|
||||
cb(out, "linear_attn_out", il);
|
||||
|
||||
return out;
|
||||
//return ggml_reshape_2d(ctx0, out, hparams.n_embd, n_tok);
|
||||
return ggml_reshape_2d(ctx0, out, hparams.n_embd, n_tok);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -4394,7 +4394,7 @@ struct llama_context_params llama_context_default_params() {
|
||||
/*.split_mode_graph_scheduling =*/ false,
|
||||
// /*.split_mode_f16 =*/ true,
|
||||
/*.scheduler_async =*/ false,
|
||||
/*.fused_delta_net =*/ 65536,
|
||||
/*.fused_delta_net =*/ 0,
|
||||
/*.mtp =*/ false,
|
||||
/*.mtp_op_type =*/ MTP_OP_NONE,
|
||||
/*.abort_callback =*/ nullptr,
|
||||
|
||||
Reference in New Issue
Block a user