Fused delta net 2 (#1320)

* Revive fused delta-net

* Add command line argument for fused delta net

* Simplify/improve CUDA delta-net

* Add -fdn to llama-bench

* More CUDA fused delta net optimizations

* CPU optimizations

* Much faster fused delta-net on the CPU

It seems it is faster than the chunked implementation!

* Change meaning of fdn from bool flag to threshold value

* Use eps = 1e-6

* Give some nodes a name

* Don't re-apply L2 norm - it has already been done

* This seems quite a bit better

* More tweaks

* Restore per context buffer size log

Not everybody uses models split in 2000 parts, and those who do,
actually want to see the biffer sizes.
This commit is contained in:
Kawrakow
2026-02-26 06:53:43 +01:00
committed by GitHub
parent 87b35dac0c
commit 2616efa296
3 changed files with 41 additions and 78 deletions

View File

@@ -84,73 +84,63 @@ __global__ void delta_net_recurrent_f32(
float * sQ = smem; // HEAD_DIM
float * sK = sQ + HEAD_DIM; // HEAD_DIM
float * sV = sK + HEAD_DIM; // HEAD_DIM
float * sKBeta = sV + HEAD_DIM; // HEAD_DIM (plain k for state update)
float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
float * sVNew = sV + HEAD_DIM; // HEAD_DIM
const float scale = rsqrtf((float)HEAD_DIM);
__shared__ float sum_helper[block_size/WARP_SIZE];
// Copy initial state to output buffer (will be updated in place)
for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) {
for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += block_size) {
state_dst[i] = state_src[i];
}
__syncthreads();
constexpr int HEAD_DIM_S = HEAD_DIM + 1;
__shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
auto all_sum1 = all_sum;
auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
// Process each token sequentially
for (int64_t t = 0; t < n_tokens; t++) {
float q_sq = 0.0f;
float k_sq = 0.0f;
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
sQ[i] = q_ptr[t * qkv_stride_token + i];
float sum_kq = 0.0f;
for (int i = tid; i < HEAD_DIM; i += block_size) {
sQ[i] = q_ptr[t * qkv_stride_token + i] * scale;
sK[i] = k_ptr[t * qkv_stride_token + i];
sV[i] = v_ptr[t * qkv_stride_token + i];
q_sq += sQ[i] * sQ[i];
k_sq += sK[i] * sK[i];
sum_kq += sK[i] * sQ[i];
}
q_sq = reduce_sum<block_size>(q_sq, sum_helper);
k_sq = reduce_sum<block_size>(k_sq, sum_helper);
float q_norm = rsqrtf(q_sq + eps);
float k_norm = rsqrtf(k_sq + eps);
float attn_score = reduce_sum<block_size>(sum_kq, sum_helper);
float beta_val = sigmoid_f(beta_ptr[t]);
float decay = expf(fminf(g_ptr[t], 50.0f));
float sum = 0;
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
sQ[i] = sQ[i] * q_norm * scale;
sK[i] = sK[i] * k_norm;
sKBeta[i] = sK[i];
sVBeta[i] = sV[i] * beta_val;
sKCumdecay[i] = sK[i] * beta_val * decay;
sum += sK[i] * sQ[i];
}
float attn_score = reduce_sum<block_size>(sum, sum_helper);
//__syncthreads();
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
float sum1 = 0.0f;
float sum2 = 0.0f;
#pragma unroll
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
float sval = state_dst[row_out + col * HEAD_DIM];
sum1 += sval * sKCumdecay[col];
sum1 += sval * sK[col];
sum2 += sval * sQ[col];
}
sum1 = warp_reduce_sum(sum1);
sum2 = warp_reduce_sum(sum2);
if (lane_id == 0) {
sVNew[row_out] = sVBeta[row_out] - sum1;
float v_attn = sVNew[row_out] * attn_score;
//sOut[row_out] = sum2 * decay + v_attn;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
}
__syncthreads();
for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
float sum1 = all_sum1[row_out];
float sum2 = all_sum2[row_out];
#pragma unroll
for (int i = 1; i < NUM_WARPS; ++i) {
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
__syncthreads();
@@ -158,19 +148,11 @@ __global__ void delta_net_recurrent_f32(
#pragma unroll
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
float state_val = state_dst[row + out_dim * HEAD_DIM];
float safe_decay = decay;
if (isnan(safe_decay) || isinf(safe_decay)) {
safe_decay = 1.0f;
}
float new_state_val = safe_decay * state_val + sVNew[row] * sKBeta[out_dim];
float new_state_val = decay * state_val + sVNew[row] * sK[out_dim];
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
}
}
if (t < n_tokens - 1) {
__syncthreads();
}
}
}
@@ -408,9 +390,7 @@ static void delta_net_f32_cuda(
const int num_blocks = n_seqs * n_heads;
constexpr int threads_per_block = 512; //256;
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
const size_t smem_size = 4 * head_dim * sizeof(float);
// Use templated kernel for common head dimensions, generic for others
if (head_dim == 64) {
@@ -421,6 +401,7 @@ static void delta_net_f32_cuda(
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
} else {
GGML_ASSERT("Unsupported delta net head size");
delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
}

View File

@@ -1397,7 +1397,6 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
static_assert(head_dim % 8 == 0);
#endif
const float eps = 1e-6f;
const float scale = 1.0f / sqrtf((float) head_dim);
float v_new_buf[head_dim];
@@ -1428,42 +1427,25 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
const float g_val = g_data[g_head_offset + t];
const float beta_raw = beta_data[g_head_offset + t];
float q_norm_sq = 0.0f;
float k_norm_sq = 0.0f;
float kq_sum = 0.0f;
#ifdef __AVX2__
auto vqsum = _mm256_setzero_ps();
auto vksum = _mm256_setzero_ps();
auto vqksum = _mm256_setzero_ps();
for (int i = 0; i < head_dim; i += 8) {
auto vq = _mm256_loadu_ps(q_t + i);
auto vk = _mm256_loadu_ps(k_t + i);
vqsum = _mm256_fmadd_ps(vq, vq, vqsum);
vksum = _mm256_fmadd_ps(vk, vk, vksum);
vqksum = _mm256_fmadd_ps(vk, vq, vqksum);
}
q_norm_sq = hsum_float_8(vqsum);
k_norm_sq = hsum_float_8(vksum);
kq_sum = hsum_float_8(vqksum);
kq_sum = hsum_float_8(vqksum);
#else
for (int i = 0; i < head_dim; ++i) {
q_norm_sq += q_t[i] * q_t[i];
k_norm_sq += k_t[i] * k_t[i];
kq_sum += k_t[i] * q_t[i];
kq_sum += k_t[i] * q_t[i];
}
#endif
const float q_norm_inv = 1.0f / sqrtf(q_norm_sq + eps);
const float k_norm_inv = 1.0f / sqrtf(k_norm_sq + eps);
const float beta_val = 1.0f / (1.0f + expf(-beta_raw));
const float decay = expf(fminf(g_val, 50.0f));
float attn_score = kq_sum * k_norm_inv * q_norm_inv * scale;
//float attn_score = 0.0f;
//for (int i = 0; i < head_dim; ++i) {
// attn_score += (k_t[i] * k_norm_inv) * (q_t[i] * q_norm_inv * scale);
//}
float attn_score = kq_sum * scale;
float * out_t = out_data + out_head_offset + t * out_token_stride;
@@ -1479,9 +1461,9 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
}
}
for (int row = 0; row < head_dim; ++row) {
const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay * k_norm_inv;
const float v_new = v_t[row] * beta_val - v_prime[row] * beta_val * decay;
v_new_buf[row] = v_new;
out_t[row] = out_val[row] * decay * q_norm_inv * scale + v_new * attn_score;
out_t[row] = out_val[row] * decay * scale + v_new * attn_score;
}
#ifdef __AVX2__
@@ -1489,7 +1471,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
auto vmin = _mm256_set1_ps(-1e6f);
auto vmax = _mm256_set1_ps( 1e6f);
for (int col = 0; col < head_dim; ++col) {
auto vk = _mm256_set1_ps(k_t[col] * k_norm_inv);
auto vk = _mm256_set1_ps(k_t[col]);
for (int row = 0; row < head_dim; row += 8) {
auto vs = _mm256_loadu_ps(state + col * head_dim + row);
auto vn = _mm256_loadu_ps(v_new_buf + row);
@@ -1503,7 +1485,7 @@ void iqk_fused_delta_net_impl(int n_heads, int n_tokens, int n_seqs,
}
#else
for (int col = 0; col < head_dim; ++col) {
const float k_col = k_t[col] * k_norm_inv;
const float k_col = k_t[col];
for (int row = 0; row < head_dim; ++row) {
float s = state[row + col * head_dim];
s = decay * s + v_new_buf[row] * k_col;

View File

@@ -2222,7 +2222,7 @@ static bool llm_load_tensors(
// print memory requirements
for (ggml_backend_buffer_t buf : model.bufs) {
LLAMA_LOG_DEBUG("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
}
// populate tensors_by_name