From facc8fdc4446b7dc72932a6f927adcd88a84b00d Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 27 Feb 2026 07:03:09 +0100 Subject: [PATCH] Very slightly better fused delta-net (#1330) --- ggml/src/ggml-cuda/delta-net.cu | 74 +++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index 98951904..acb6f6c4 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -96,9 +96,10 @@ __global__ void delta_net_recurrent_f32( } constexpr int HEAD_DIM_S = HEAD_DIM + 1; - __shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS]; + constexpr int num_stored_rows = block_size >= HEAD_DIM && block_size % HEAD_DIM == 0 ? block_size/HEAD_DIM : NUM_WARPS; + __shared__ float all_sum[2*HEAD_DIM_S*num_stored_rows]; auto all_sum1 = all_sum; - auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS; + auto all_sum2 = all_sum1 + HEAD_DIM_S*num_stored_rows; // Process each token sequentially for (int64_t t = 0; t < n_tokens; t++) { @@ -116,39 +117,68 @@ __global__ void delta_net_recurrent_f32( float beta_val = sigmoid_f(beta_ptr[t]); float decay = expf(fminf(g_ptr[t], 50.0f)); - for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) { - float sum1 = 0.0f; - float sum2 = 0.0f; + if constexpr (block_size >= HEAD_DIM && block_size % HEAD_DIM == 0) { + int idx = tid / HEAD_DIM; + int row_out = tid % HEAD_DIM; + float sum1 = 0, sum2 = 0; #pragma unroll - for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + for (int col = idx; col < HEAD_DIM; col += block_size/HEAD_DIM) { float sval = state_dst[row_out + col * HEAD_DIM]; sum1 += sval * sK[col]; sum2 += sval * sQ[col]; } - all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1; - all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2; - } - __syncthreads(); + all_sum1[idx*HEAD_DIM_S + row_out] = sum1; + all_sum2[idx*HEAD_DIM_S + row_out] = sum2; - for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) { - float sum1 = all_sum1[row_out]; - float sum2 = all_sum2[row_out]; - #pragma unroll - for (int i = 1; i < NUM_WARPS; ++i) { - sum1 += all_sum1[row_out + i*HEAD_DIM_S]; - sum2 += all_sum2[row_out + i*HEAD_DIM_S]; + __syncthreads(); + + if (idx == 0) { + #pragma unroll + for (int i = 1; i < block_size/HEAD_DIM; ++i) { + sum1 += all_sum1[i*HEAD_DIM_S + row_out]; + sum2 += all_sum2[i*HEAD_DIM_S + row_out]; + } + sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay; + float v_attn = sVNew[row_out] * attn_score; + out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; } - sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay; - float v_attn = sVNew[row_out] * attn_score; - out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; + __syncthreads(); + } else { + for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) { + float sum1 = 0.0f; + float sum2 = 0.0f; + #pragma unroll + for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { + float sval = state_dst[row_out + col * HEAD_DIM]; + sum1 += sval * sK[col]; + sum2 += sval * sQ[col]; + } + all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1; + all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2; + } + __syncthreads(); + + for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) { + float sum1 = all_sum1[row_out]; + float sum2 = all_sum2[row_out]; + #pragma unroll + for (int i = 1; i < NUM_WARPS; ++i) { + sum1 += all_sum1[row_out + i*HEAD_DIM_S]; + sum2 += all_sum2[row_out + i*HEAD_DIM_S]; + } + sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay; + float v_attn = sVNew[row_out] * attn_score; + out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; + } + __syncthreads(); } - __syncthreads(); for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) { + float k_col = sK[out_dim]; #pragma unroll for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { float state_val = state_dst[row + out_dim * HEAD_DIM]; - float new_state_val = decay * state_val + sVNew[row] * sK[out_dim]; + float new_state_val = decay * state_val + sVNew[row] * k_col; //sK[out_dim]; new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f); state_dst[row + out_dim * HEAD_DIM] = new_state_val; }