From 745dee7d4e65aa60b3830a0f933d205522108649 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Fri, 27 Feb 2026 09:37:22 +0000
Subject: [PATCH] Cleanup

---
 ggml/src/ggml-cuda/delta-net.cu | 26 +++++++-------------------
 1 file changed, 7 insertions(+), 19 deletions(-)

diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu
index 0f8c8878..5146b340 100644
--- a/ggml/src/ggml-cuda/delta-net.cu
+++ b/ggml/src/ggml-cuda/delta-net.cu
@@ -91,17 +91,15 @@ __global__ void delta_net_recurrent_f32_a(
     constexpr int num_warps = block_size/WARP_SIZE;
     const int row = tid % WARP_SIZE;
     const int col_idx_0 = tid / WARP_SIZE;
+    const int row_out = row + sub_idx * WARP_SIZE;
 
+    // Keep the state in registers, copy the final state to its destination at the end
     float state_local[HEAD_DIM/num_warps];
     for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
         int col = num_warps*i + col_idx_0;
-        state_local[i] = state_src[col*HEAD_DIM + row + sub_idx * WARP_SIZE];
+        state_local[i] = state_src[col*HEAD_DIM + row_out];
     }
 
-    //for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
-    //    state_dst[col*HEAD_DIM + row + sub_idx * WARP_SIZE] = state_src[col*HEAD_DIM + row + sub_idx * WARP_SIZE];
-    //}
-
     constexpr int WARP_SIZE_S = WARP_SIZE + 1;
     constexpr int num_stored_rows = block_size/WARP_SIZE;
     __shared__ float all_sum[2*WARP_SIZE_S*num_stored_rows];
@@ -128,11 +126,6 @@ __global__ void delta_net_recurrent_f32_a(
             sum1 += state_local[i] * sK[col];
             sum2 += state_local[i] * sQ[col];
         }
-        //for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
-        //    float sval = state_dst[row + sub_idx * WARP_SIZE + col * HEAD_DIM];
-        //    sum1 += sval * sK[col];
-        //    sum2 += sval * sQ[col];
-        //}
 
         all_sum1[col_idx_0*WARP_SIZE_S + row] = sum1;
         all_sum2[col_idx_0*WARP_SIZE_S + row] = sum2;
@@ -144,9 +137,9 @@ __global__ void delta_net_recurrent_f32_a(
             sum1 += all_sum1[i*WARP_SIZE_S + row];
             sum2 += all_sum2[i*WARP_SIZE_S + row];
         }
-        float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row + sub_idx*WARP_SIZE] - sum1 * decay);
+        float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
         if (col_idx_0 == 0) {
-            out_base[t * out_token_stride + row + sub_idx*WARP_SIZE] = sum2 * decay + sv_new * attn_score;
+            out_base[t * out_token_stride + row_out] = sum2 * decay + sv_new * attn_score;
         }
 
         for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
@@ -156,16 +149,11 @@ __global__ void delta_net_recurrent_f32_a(
             state_local[i] = new_state_val;
         }
 
-        //for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
-        //    float state_val = state_dst[row + sub_idx*WARP_SIZE + col * HEAD_DIM];
-        //    float new_state_val = decay * state_val + sv_new * sK[col];
-        //    new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
-        //    state_dst[row + sub_idx*WARP_SIZE + col * HEAD_DIM] = new_state_val;
-        //}
     }
 
+    // Copy the final state to its destination
     for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
         int col = num_warps*i + col_idx_0;
-        state_dst[col*HEAD_DIM + row + sub_idx * WARP_SIZE] = state_local[i];
+        state_dst[col*HEAD_DIM + row_out] = state_local[i];
     }
 }
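
Note: the sketch below is not part of the patch. It is a minimal, stand-alone illustration of the pattern the cleanup settles on -- each thread keeps its slice of the recurrent state in registers for the whole token loop and writes it back to global memory once at the end -- using toy names (toy_recurrent_state, DIM, k, v) and a simplified update rule that do not appear in delta-net.cu.

// Toy register-resident state update. Assumptions: one block of DIM threads,
// DIM x DIM state, simple rank-1 update per token (no decay/beta/clamping).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

constexpr int DIM = 32;  // toy head dimension; thread i owns row i of the state

__global__ void toy_recurrent_state(const float * state_src, float * state_dst,
                                    const float * k, const float * v, int n_tokens) {
    const int row = threadIdx.x;
    float state_local[DIM];        // register-resident copy of this thread's state row

#pragma unroll
    for (int col = 0; col < DIM; ++col) {
        state_local[col] = state_src[col*DIM + row];
    }

    for (int t = 0; t < n_tokens; ++t) {
        const float vt = v[t*DIM + row];
#pragma unroll
        for (int col = 0; col < DIM; ++col) {
            // rank-1 update; no global-memory reads or writes of the state per token
            state_local[col] += vt * k[t*DIM + col];
        }
    }

    // copy the final state to its destination once
#pragma unroll
    for (int col = 0; col < DIM; ++col) {
        state_dst[col*DIM + row] = state_local[col];
    }
}

int main() {
    const int n_tokens = 4;
    std::vector<float> h_state(DIM*DIM, 0.0f), h_k(n_tokens*DIM, 1.0f), h_v(n_tokens*DIM, 0.5f);

    float *d_src, *d_dst, *d_k, *d_v;
    cudaMalloc((void **)&d_src, DIM*DIM*sizeof(float));
    cudaMalloc((void **)&d_dst, DIM*DIM*sizeof(float));
    cudaMalloc((void **)&d_k, n_tokens*DIM*sizeof(float));
    cudaMalloc((void **)&d_v, n_tokens*DIM*sizeof(float));
    cudaMemcpy(d_src, h_state.data(), DIM*DIM*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_k, h_k.data(), n_tokens*DIM*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_v, h_v.data(), n_tokens*DIM*sizeof(float), cudaMemcpyHostToDevice);

    toy_recurrent_state<<<1, DIM>>>(d_src, d_dst, d_k, d_v, n_tokens);
    cudaMemcpy(h_state.data(), d_dst, DIM*DIM*sizeof(float), cudaMemcpyDeviceToHost);

    printf("state[0][0] = %g\n", h_state[0]);  // expect n_tokens * 0.5 * 1.0 = 2
    cudaFree(d_src); cudaFree(d_dst); cudaFree(d_k); cudaFree(d_v);
    return 0;
}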