From a8ef7e20e7924ceda9bfcbcb935372845943c64e Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Wed, 25 Feb 2026 13:11:47 +0000
Subject: [PATCH] More tweaks

---
 ggml/src/ggml-cuda/delta-net.cu | 54 ++++++++-------------------------
 1 file changed, 13 insertions(+), 41 deletions(-)

diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu
index ffefa894..98951904 100644
--- a/ggml/src/ggml-cuda/delta-net.cu
+++ b/ggml/src/ggml-cuda/delta-net.cu
@@ -95,28 +95,27 @@ __global__ void delta_net_recurrent_f32(
         state_dst[i] = state_src[i];
     }
 
-    __shared__ float all_sum[2*HEAD_DIM*NUM_WARPS];
+    constexpr int HEAD_DIM_S = HEAD_DIM + 1;
+    __shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
     auto all_sum1 = all_sum;
-    auto all_sum2 = all_sum1 + HEAD_DIM*NUM_WARPS;
+    auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
 
     // Process each token sequentially
     for (int64_t t = 0; t < n_tokens; t++) {
         float sum_kq = 0.0f;
         for (int i = tid; i < HEAD_DIM; i += block_size) {
-            sQ[i] = q_ptr[t * qkv_stride_token + i];
+            sQ[i] = q_ptr[t * qkv_stride_token + i] * scale;
             sK[i] = k_ptr[t * qkv_stride_token + i];
             sV[i] = v_ptr[t * qkv_stride_token + i];
             sum_kq += sK[i] * sQ[i];
         }
-        sum_kq = reduce_sum(sum_kq, sum_helper);
+        float attn_score = reduce_sum(sum_kq, sum_helper);
 
         float beta_val = sigmoid_f(beta_ptr[t]);
         float decay = expf(fminf(g_ptr[t], 50.0f));
 
-        float attn_score = sum_kq * scale;
-
         for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
             float sum1 = 0.0f;
             float sum2 = 0.0f;
@@ -126,45 +125,25 @@ __global__ void delta_net_recurrent_f32(
                 sum1 += sval * sK[col];
                 sum2 += sval * sQ[col];
             }
-            all_sum1[warp_id*HEAD_DIM + row_out] = sum1;
-            all_sum2[warp_id*HEAD_DIM + row_out] = sum2;
+            all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
+            all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
         }
         __syncthreads();
 
        for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
             float sum1 = all_sum1[row_out];
             float sum2 = all_sum2[row_out];
+            #pragma unroll
             for (int i = 1; i < NUM_WARPS; ++i) {
-                sum1 += all_sum1[row_out + i*HEAD_DIM];
-                sum2 += all_sum2[row_out + i*HEAD_DIM];
+                sum1 += all_sum1[row_out + i*HEAD_DIM_S];
+                sum2 += all_sum2[row_out + i*HEAD_DIM_S];
             }
-            sum1 *= beta_val * decay;
-            sum2 *= scale * decay;
-            sVNew[row_out] = sV[row_out] * beta_val - sum1;
+            sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
             float v_attn = sVNew[row_out] * attn_score;
-            out_base[t * out_token_stride + row_out] = sum2 + v_attn;
+            out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
         }
         __syncthreads();
 
-        //for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
-        //    float sum1 = 0.0f;
-        //    float sum2 = 0.0f;
-        //    #pragma unroll
-        //    for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
-        //        float sval = state_dst[row_out + col * HEAD_DIM];
-        //        sum1 += sval * sK[col];
-        //        sum2 += sval * sQ[col];
-        //    }
-        //    sum1 = warp_reduce_sum(sum1) * beta_val * decay;
-        //    sum2 = warp_reduce_sum(sum2) * scale * decay;
-        //    if (lane_id == 0) {
-        //        sVNew[row_out] = sV[row_out] * beta_val - sum1;
-        //        float v_attn = sVNew[row_out] * attn_score;
-        //        out_base[t * out_token_stride + row_out] = sum2 + v_attn;
-        //    }
-        //}
-        //__syncthreads();
-
         for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
             #pragma unroll
             for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
@@ -174,10 +153,6 @@ __global__ void delta_net_recurrent_f32(
                 state_dst[row + out_dim * HEAD_DIM] = new_state_val;
             }
         }
 
-        //if (t < n_tokens - 1) {
-        //    __syncthreads();
-        //}
-
     }
 }
@@ -415,10 +390,6 @@ static void delta_net_f32_cuda(
     const int num_blocks = n_seqs * n_heads;
     constexpr int threads_per_block = 512; //256;
 
-    // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
-    // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
-    //const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
-    //const size_t smem_size = (4 * head_dim + 2 * n_tokens) * sizeof(float);
     const size_t smem_size = 4 * head_dim * sizeof(float);
 
     // Use templated kernel for common head dimensions, generic for others
@@ -430,6 +401,7 @@ static void delta_net_f32_cuda(
         delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
             q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
     } else {
+        GGML_ASSERT(false && "Unsupported delta net head size");
         delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
             q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
     }
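
Note (not part of the patch): the arithmetic changes in the first two hunks
are an algebraic refactor, identical up to floating-point rounding order.
Folding scale into the sQ load means the reduced sum_kq already equals
scale * dot(k, q), so the separate attn_score = sum_kq * scale line can go,
and sum2 = (S q')[row] comes out pre-scaled, leaving only the decay factor
to apply at the output:

    old: out[row] = scale*decay*(S q)[row] + vNew[row] * scale*dot(k, q)
    new: out[row] = decay*(S q')[row]      + vNew[row] * dot(k, q'),   q' = scale*q

which agree since (S q')[row] = scale*(S q)[row] and dot(k, q') =
scale*dot(k, q). Likewise vNew[row] = beta*v[row] - beta*decay*(S k)[row]
is now formed in a single expression instead of scaling sum1 in a separate
pass. The 4 * head_dim floats of shared memory computed in the last hunk
match the four per-head scratch vectors the kernel uses (sQ, sK, sV, sVNew).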
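
Note (not part of the patch): HEAD_DIM_S = HEAD_DIM + 1 is the classic
shared-memory padding trick. When a 2D scratch array has a row stride that
is a multiple of 32 floats and the 32 lanes of a warp touch the same column
of 32 consecutive rows in one instruction, all 32 addresses map to the same
bank and the access serializes; padding each row by one float spreads them
across all 32 banks. A standalone toy illustrating that mechanism (made-up
names, not code from this file):

    // toy_bank_pad.cu: one warp sums the rows of a 32x32 tile staged in
    // shared memory. With a row stride of 32 floats, at iteration c every
    // lane L loads tile[L*32 + c], and all 32 addresses fall into bank
    // (c % 32): a 32-way conflict. With stride 33 the banks are
    // (L + c) % 32: conflict-free.
    #include <cstdio>

    constexpr int N   = 32;
    constexpr int PAD = N + 1; // the "+1" being illustrated

    __global__ void row_sums(const float * in, float * out) {
        __shared__ float tile[N * PAD];
        const int lane = threadIdx.x; // launched with a single warp
        for (int c = 0; c < N; ++c) {
            tile[lane * PAD + c] = in[lane * N + c]; // stage row 'lane'
        }
        __syncwarp();
        float s = 0.0f;
        for (int c = 0; c < N; ++c) {
            s += tile[lane * PAD + c]; // conflict-free thanks to PAD
        }
        out[lane] = s;
    }

    int main() {
        float h_in[N * N], h_out[N];
        for (int i = 0; i < N * N; ++i) h_in[i] = 1.0f;
        float * d_in, * d_out;
        cudaMalloc(&d_in,  sizeof(h_in));
        cudaMalloc(&d_out, sizeof(h_out));
        cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
        row_sums<<<1, 32>>>(d_in, d_out);
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
        printf("row 0 sum = %g (expect 32)\n", h_out[0]);
        cudaFree(d_in);
        cudaFree(d_out);
        return 0;
    }

Whether the padding helps the real kernel depends on how all_sum1/all_sum2
are actually indexed per warp; the toy only demonstrates the general access
pattern the +1 is meant to defuse.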