From 7af6892dcfa8d6aa08c994aab706aa5d05c766ba Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Tue, 24 Feb 2026 10:00:03 +0000 Subject: [PATCH] More CUDA fused delta net optimizations --- ggml/src/ggml-cuda/delta-net.cu | 59 +++++++++++++-------------------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu index 40b62e49..367ae67d 100644 --- a/ggml/src/ggml-cuda/delta-net.cu +++ b/ggml/src/ggml-cuda/delta-net.cu @@ -42,11 +42,11 @@ __global__ void delta_net_recurrent_f32( const int64_t output_offset, // offset where state starts in output const float eps) { const int batch_idx = blockIdx.x / n_heads; - const int head_idx = blockIdx.x % n_heads; + const int head_idx = blockIdx.x % n_heads; const int tid = threadIdx.x; const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads const int lane_id = tid % WARP_SIZE; // 0-31 - constexpr int NUM_WARPS = block_size/WARP_SIZE; // 256 / 32 + constexpr int NUM_WARPS = block_size/WARP_SIZE; // Strides for input tensors (column-major) // Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs] @@ -88,8 +88,7 @@ __global__ void delta_net_recurrent_f32( float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta)) float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g)) - float * sVPrime = sKCumdecay + HEAD_DIM; // HEAD_DIM (state @ k_cumdecay) - float * sVNew = sVPrime + HEAD_DIM; // HEAD_DIM (v_beta - v_prime) + float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime) const float scale = rsqrtf((float)HEAD_DIM); @@ -123,52 +122,41 @@ __global__ void delta_net_recurrent_f32( float beta_val = sigmoid_f(beta_ptr[t]); float decay = expf(fminf(g_ptr[t], 50.0f)); + float sum = 0; for (int i = tid; i < HEAD_DIM; i += blockDim.x) { sQ[i] = sQ[i] * q_norm * scale; sK[i] = sK[i] * k_norm; sKBeta[i] = sK[i]; sVBeta[i] = sV[i] * beta_val; sKCumdecay[i] = sK[i] * beta_val * decay; - } - __syncthreads(); - - for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll - for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { - sum += state_dst[row_out + col * HEAD_DIM] * sKCumdecay[col]; - } - sum = warp_reduce_sum(sum); - if (lane_id == 0) { - sVPrime[row_out] = sum; - } - } - __syncthreads(); - - float sum = 0; - for (int i = tid; i < HEAD_DIM; i += block_size) { - sVNew[i] = sVBeta[i] - sVPrime[i]; sum += sK[i] * sQ[i]; } float attn_score = reduce_sum(sum, sum_helper); + //__syncthreads(); for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { - float sum = 0.0f; + float sum1 = 0.0f; + float sum2 = 0.0f; #pragma unroll for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { - float state_val = state_dst[row_out + col * HEAD_DIM]; - sum += sQ[col] * decay * state_val; + float sval = state_dst[row_out + col * HEAD_DIM]; + sum1 += sval * sKCumdecay[col]; + sum2 += sval * sQ[col]; } - sum = warp_reduce_sum(sum); + sum1 = warp_reduce_sum(sum1); + sum2 = warp_reduce_sum(sum2); if (lane_id == 0) { + sVNew[row_out] = sVBeta[row_out] - sum1; float v_attn = sVNew[row_out] * attn_score; - sOut[row_out] = sum + v_attn; + //sOut[row_out] = sum2 * decay + v_attn; + out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn; } } __syncthreads(); - for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += block_size) { - for (int row = 0; row < HEAD_DIM; row++) { + for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) { + #pragma unroll + for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { float state_val = state_dst[row + out_dim * HEAD_DIM]; float safe_decay = decay; if (isnan(safe_decay) || isinf(safe_decay)) { @@ -179,12 +167,10 @@ __global__ void delta_net_recurrent_f32( state_dst[row + out_dim * HEAD_DIM] = new_state_val; } } - __syncthreads(); - - for (int i = tid; i < HEAD_DIM; i += block_size) { - out_base[t * out_token_stride + i] = sOut[i]; + if (t < n_tokens - 1) { + __syncthreads(); } - __syncthreads(); + } } @@ -420,7 +406,7 @@ static void delta_net_f32_cuda( // One block per (batch, head) pair const int num_blocks = n_seqs * n_heads; - constexpr int threads_per_block = 256; + constexpr int threads_per_block = 512; //256; // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew) // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score @@ -431,6 +417,7 @@ static void delta_net_f32_cuda( delta_net_recurrent_f32<64, threads_per_block><<>>( q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); } else if (head_dim == 128) { + GGML_ASSERT(num_blocks % 8 == 0); delta_net_recurrent_f32<128, threads_per_block><<>>( q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps); } else {