From dc44a37ca257410809cdbe55a65f78194d722ddd Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Tue, 24 Feb 2026 07:41:09 +0000
Subject: [PATCH] Simplify/improve CUDA delta-net

---
 ggml/src/ggml-cuda/delta-net.cu | 1306 ++-----------------------------
 src/llama-delta-net.cpp         |   17 +-
 2 files changed, 65 insertions(+), 1258 deletions(-)

diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu
index 8291a9f8..40b62e49 100644
--- a/ggml/src/ggml-cuda/delta-net.cu
+++ b/ggml/src/ggml-cuda/delta-net.cu
@@ -10,10 +10,24 @@
 __device__ __forceinline__ float sigmoid_f(float x) {
     return 1.0f / (1.0f + expf(-x));
 }
 
-// Token-by-token recurrent kernel
-// One block per (batch, head) pair, processes all tokens sequentially
-// State is kept in global memory (too large for shared memory at HEAD_DIM=128)
-template <int HEAD_DIM>
+template <int block_size>
+__device__ __forceinline__ float reduce_sum(float x, float * s) {
+    x = warp_reduce_sum(x);
+    if constexpr (block_size > WARP_SIZE) {
+        __syncthreads(); // make sure a previous reduction through s is done reading it
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s[warp_id] = x;
+        }
+        __syncthreads();
+        x = lane_id < block_size/WARP_SIZE ? s[lane_id] : 0.0f;
+        x = warp_reduce_sum(x);
+    }
+    return x;
+}
+
+template <int HEAD_DIM, int block_size>
 __global__ void delta_net_recurrent_f32(
     const float * __restrict__ q,        // [HEAD_DIM, n_tokens, n_heads, n_seqs]
     const float * __restrict__ k,        // [HEAD_DIM, n_tokens, n_heads, n_seqs]
@@ -22,18 +36,17 @@ __global__ void delta_net_recurrent_f32(
     const float * __restrict__ v,        // [HEAD_DIM, n_tokens, n_heads, n_seqs]
     const float * __restrict__ g,        // [n_tokens, 1, n_heads, n_seqs]
     const float * __restrict__ beta_in,  // [1, n_tokens, n_heads, n_seqs]
     const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs]
     float * __restrict__ dst,            // output + new_state concatenated
-    const int64_t n_tokens, const int64_t n_heads,
+    const int64_t n_heads, const int64_t n_tokens,
     const int64_t n_seqs,
     const int64_t output_offset,         // offset where state starts in output
-    const float eps)
-{
+    const float eps) {
     const int batch_idx = blockIdx.x / n_heads;
     const int head_idx = blockIdx.x % n_heads;
     const int tid = threadIdx.x;
     const int warp_id = tid / WARP_SIZE;  // 0-7 for 256 threads
     const int lane_id = tid % WARP_SIZE;  // 0-31
-    constexpr int NUM_WARPS = 8;  // 256 / 32
+    constexpr int NUM_WARPS = block_size/WARP_SIZE;
 
     // Strides for input tensors (column-major)
     // Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
@@ -66,9 +79,6 @@ __global__ void delta_net_recurrent_f32(
     const int64_t out_token_stride = HEAD_DIM * n_heads;  // stride between tokens
     float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset;
 
-    // Shared memory for scalars (moved outside loop for clarity)
-    __shared__ float shared_g_val, shared_beta_val, shared_decay, shared_attn_score;
-
     // Shared memory for current token's Q, K, V (normalized), and intermediate results
     extern __shared__ float smem[];
     float * sQ = smem;                       // HEAD_DIM
@@ -80,75 +90,42 @@ __global__ void delta_net_recurrent_f32(
     float * sKCumdecay = sOut + HEAD_DIM;    // HEAD_DIM (k * sigmoid(beta) * exp(g))
     float * sVPrime = sKCumdecay + HEAD_DIM; // HEAD_DIM (state @ k_cumdecay)
     float * sVNew = sVPrime + HEAD_DIM;      // HEAD_DIM (v_beta - v_prime)
-    float * sNorm = sVNew + HEAD_DIM;        // 2 (for Q and K norms)
 
     const float scale = rsqrtf((float)HEAD_DIM);
 
+    __shared__ float sum_helper[block_size/WARP_SIZE];
+
     // Copy initial state to output buffer (will be updated in place)
     for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) {
-        int col = i / HEAD_DIM;
-        int row = i % HEAD_DIM;
-        // Column-major: state[row, col] at index
row + col*HEAD_DIM
-        state_dst[row + col * HEAD_DIM] = state_src[row + col * HEAD_DIM];
+        state_dst[i] = state_src[i];
     }
     __syncthreads();
 
     // Process each token sequentially
     for (int64_t t = 0; t < n_tokens; t++) {
-        // Reset norm accumulators
-        if (tid < 2) {
-            sNorm[tid] = 0.0f;
-        }
-        __syncthreads();
+        float q_sq = 0.0f;
+        float k_sq = 0.0f;
 
         for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
             sQ[i] = q_ptr[t * qkv_stride_token + i];
             sK[i] = k_ptr[t * qkv_stride_token + i];
             sV[i] = v_ptr[t * qkv_stride_token + i];
-        }
-        __syncthreads();
-
-        float q_sq_local = 0.0f;
-        float k_sq_local = 0.0f;
-        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
-            q_sq_local += sQ[i] * sQ[i];
-            k_sq_local += sK[i] * sK[i];
+            q_sq += sQ[i] * sQ[i];
+            k_sq += sK[i] * sK[i];
         }
 
-        // Warp reduction
-        #pragma unroll
-        for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
-            q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset);
-            k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset);
-        }
+        q_sq = reduce_sum<block_size>(q_sq, sum_helper);
+        k_sq = reduce_sum<block_size>(k_sq, sum_helper);
 
-        // Cross-warp reduction using shared memory atomics
-        if (tid % WARP_SIZE == 0) {
-            atomicAdd(&sNorm[0], q_sq_local);
-            atomicAdd(&sNorm[1], k_sq_local);
-        }
-        __syncthreads();
+        float q_norm = rsqrtf(q_sq + eps);
+        float k_norm = rsqrtf(k_sq + eps);
 
-        float q_norm = rsqrtf(sNorm[0] + eps);
-        float k_norm = rsqrtf(sNorm[1] + eps);
+        float beta_val = sigmoid_f(beta_ptr[t]);
+        float decay = expf(fminf(g_ptr[t], 50.0f));
 
         for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
             sQ[i] = sQ[i] * q_norm * scale;
             sK[i] = sK[i] * k_norm;
-        }
-        __syncthreads();
-
-        if (tid == 0) {
-            shared_g_val = g_ptr[t];
-            shared_beta_val = sigmoid_f(beta_ptr[t]);
-            shared_decay = expf(fminf(shared_g_val, 50.0f));
-        }
-        __syncthreads();
-
-        float beta_val = shared_beta_val;
-        float decay = shared_decay;
-
-        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
             sKBeta[i] = sK[i];
             sVBeta[i] = sV[i] * beta_val;
             sKCumdecay[i] = sK[i] * beta_val * decay;
@@ -157,61 +134,40 @@ __global__ void delta_net_recurrent_f32(
         for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
             float sum = 0.0f;
-            #pragma unroll 4
+            #pragma unroll
             for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
                 sum += state_dst[row_out + col * HEAD_DIM] * sKCumdecay[col];
             }
-            #pragma unroll
-            for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
-                sum += __shfl_xor_sync(0xffffffff, sum, offset);
-            }
+            sum = warp_reduce_sum(sum);
             if (lane_id == 0) {
                 sVPrime[row_out] = sum;
             }
         }
         __syncthreads();
 
-        for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
+        float sum = 0;
+        for (int i = tid; i < HEAD_DIM; i += block_size) {
             sVNew[i] = sVBeta[i] - sVPrime[i];
+            sum += sK[i] * sQ[i];
         }
-        __syncthreads();
-
-        if (warp_id == 0) {
-            float sum = 0.0f;
-            #pragma unroll 4
-            for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) {
-                sum += sK[i] * sQ[i];
-            }
-            // Warp reduction
-            #pragma unroll
-            for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
-                sum += __shfl_xor_sync(0xffffffff, sum, offset);
-            }
-            if (lane_id == 0) {
-                shared_attn_score = sum;
-            }
-        }
-        __syncthreads();
+        float attn_score = reduce_sum<block_size>(sum, sum_helper);
 
         for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
             float sum = 0.0f;
-            #pragma unroll 4
+            #pragma unroll
             for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
                 float state_val = state_dst[row_out + col * HEAD_DIM];
                 sum += sQ[col] * decay * state_val;
             }
-            #pragma unroll
-            for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
-                sum +=
__shfl_xor_sync(0xffffffff, sum, offset); - } + sum = warp_reduce_sum(sum); if (lane_id == 0) { - float v_attn = sVNew[row_out] * shared_attn_score; + float v_attn = sVNew[row_out] * attn_score; sOut[row_out] = sum + v_attn; } } __syncthreads(); - for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += blockDim.x) { + for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += block_size) { for (int row = 0; row < HEAD_DIM; row++) { float state_val = state_dst[row + out_dim * HEAD_DIM]; float safe_decay = decay; @@ -225,7 +181,7 @@ __global__ void delta_net_recurrent_f32( } __syncthreads(); - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { + for (int i = tid; i < HEAD_DIM; i += block_size) { out_base[t * out_token_stride + i] = sOut[i]; } __syncthreads(); @@ -246,8 +202,7 @@ __global__ void delta_net_recurrent_generic_f32( const int64_t n_heads, const int64_t n_seqs, const int64_t output_offset, - const float eps) -{ + const float eps) { const int batch_idx = blockIdx.x / n_heads; const int head_idx = blockIdx.x % n_heads; const int tid = threadIdx.x; @@ -442,975 +397,6 @@ __global__ void delta_net_recurrent_generic_f32( } } -// FP16 DeltaNet kernel using __hfma2 for 2x throughput -#if !defined(GGML_USE_HIP) -template -__global__ void delta_net_fp16_optimized( - const float * __restrict__ q, - const float * __restrict__ k, - const float * __restrict__ v, - const float * __restrict__ g, - const float * __restrict__ beta_in, - const float * __restrict__ state_in, - float * __restrict__ dst, - const int64_t n_tokens, - const int64_t n_heads, - const int64_t n_seqs, - const int64_t output_offset, - const float eps) -{ - static_assert(HEAD_DIM == 128, "FP16 kernel requires HEAD_DIM=128"); - static_assert(HEAD_DIM % 2 == 0, "HEAD_DIM must be even for half2"); - - const int batch_idx = blockIdx.x / n_heads; - const int head_idx = blockIdx.x % n_heads; - const int tid = threadIdx.x; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - constexpr int NUM_WARPS = 8; // 256 threads / 32 - - // Strides (column-major) - const int64_t qkv_stride_token = HEAD_DIM; - const int64_t qkv_stride_head = HEAD_DIM * n_tokens; - const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; - const int64_t g_stride_head = n_tokens; - const int64_t g_stride_batch = n_tokens * n_heads; - const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; - const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; - - // Pointers - const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; - float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; - const int64_t out_token_stride = HEAD_DIM * n_heads; - float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; - - // Shared memory layout: - // - FP16 state COLUMN-MAJOR: 128×128 = 16384 half = 32KB - // - FP16 vectors: K, KCumdecay, Q_scaled = 3 × 128 = 384 half = 768 bytes - // - FP32 vectors: V, KBeta, VBeta, Out, VPrime, VNew = 6 × 128 = 768 floats = 3KB - // Total: ~36KB - - extern __shared__ char 
smem_raw[]; - - // FP16 state COLUMN-MAJOR: state[row, col] = state_smem[row + col * HEAD_DIM] - half * state_smem = (half *)smem_raw; - - // FP16 vectors - half * sK_fp16 = (half *)(smem_raw + HEAD_DIM * HEAD_DIM * sizeof(half)); - half * sKCumdecay_fp16 = sK_fp16 + HEAD_DIM; - half * sQ_fp16 = sKCumdecay_fp16 + HEAD_DIM; - - // FP32 vectors - float * sV = (float *)(sQ_fp16 + HEAD_DIM); - float * sKBeta = sV + HEAD_DIM; - float * sVBeta = sKBeta + HEAD_DIM; - float * sOut = sVBeta + HEAD_DIM; - float * sVPrime = sOut + HEAD_DIM; - float * sVNew = sVPrime + HEAD_DIM; - float * sNorm = sVNew + HEAD_DIM; - - __shared__ float shared_decay, shared_attn_score; - - const float scale = rsqrtf((float)HEAD_DIM); - - // Load initial state DIRECTLY (no transpose - same layout as global) - // state[row, col] = state_smem[row + col * HEAD_DIM] - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - state_smem[i] = __float2half(state_src[i]); - } - __syncthreads(); - - // Process each token - for (int64_t t = 0; t < n_tokens; t++) { - // Reset norms - if (tid < 2) { - sNorm[tid] = 0.0f; - } - __syncthreads(); - - // 1. Load Q, K, V and compute norms - float q_sq_local = 0.0f, k_sq_local = 0.0f; - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - float q_val = q_ptr[t * qkv_stride_token + i]; - float k_val = k_ptr[t * qkv_stride_token + i]; - sV[i] = v_ptr[t * qkv_stride_token + i]; - q_sq_local += q_val * q_val; - k_sq_local += k_val * k_val; - sVPrime[i] = q_val; // Temp storage for Q - sVNew[i] = k_val; // Temp storage for K - } - - // Warp reduction for norms - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); - k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); - } - if (lane_id == 0) { - atomicAdd(&sNorm[0], q_sq_local); - atomicAdd(&sNorm[1], k_sq_local); - } - __syncthreads(); - - float q_norm = rsqrtf(sNorm[0] + eps); - float k_norm = rsqrtf(sNorm[1] + eps); - - // 2. Load g and beta, compute decay - if (tid == 0) { - shared_decay = expf(fminf(g_ptr[t], 50.0f)); // Clamp g to prevent overflow - } - __syncthreads(); - float decay = shared_decay; - float beta_val = sigmoid_f(beta_ptr[t]); - - // 3. Compute normalized vectors and convert to FP16 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - float q_normalized = sVPrime[i] * q_norm * scale; - float k_normalized = sVNew[i] * k_norm; - - sQ_fp16[i] = __float2half(q_normalized * decay); - sK_fp16[i] = __float2half(k_normalized); - sKCumdecay_fp16[i] = __float2half(k_normalized * beta_val * decay); - - sKBeta[i] = k_normalized; - sVBeta[i] = sV[i] * beta_val; - } - __syncthreads(); - - // 4. v_prime = state @ k_cumdecay using half2 - // Column-major: state[row, col] = state_smem[row + col * HEAD_DIM] - // v_prime[col] = sum_row(state[row, col] * k_cumdecay[row]) - // For fixed col, state[0,col], state[1,col], ... = state_smem[col*128], state_smem[col*128+1], ... - // These ARE contiguous! Can use half2. 
- for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { - half2 sum_h2 = __float2half2_rn(0.0f); - half2 * state_col = (half2 *)(&state_smem[col * HEAD_DIM]); - half2 * vec_h2 = (half2 *)sKCumdecay_fp16; - - #pragma unroll 2 - for (int row = lane_id; row < HEAD_DIM / 2; row += WARP_SIZE) { - sum_h2 = __hfma2(state_col[row], vec_h2[row], sum_h2); - } - - float sum = __half2float(sum_h2.x) + __half2float(sum_h2.y); - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - - if (lane_id == 0) { - sVPrime[col] = sum; - } - } - __syncthreads(); - - // 5. v_new = v_beta - v_prime - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - sVNew[i] = sVBeta[i] - sVPrime[i]; - } - __syncthreads(); - - // 6. attn_score = dot(k, q) in FP32 - if (warp_id == 0) { - float sum = 0.0f; - for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { - sum += sKBeta[i] * __half2float(sQ_fp16[i]) / decay; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) { - shared_attn_score = sum; - } - } - __syncthreads(); - - // 7. output = attn_inter + v_attn - // attn_inter[col] = sum_row(state[row, col] * q_scaled[row]) - // Same pattern as v_prime - columns are contiguous! - for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { - half2 sum_h2 = __float2half2_rn(0.0f); - half2 * state_col = (half2 *)(&state_smem[col * HEAD_DIM]); - half2 * vec_h2 = (half2 *)sQ_fp16; - - #pragma unroll 2 - for (int row = lane_id; row < HEAD_DIM / 2; row += WARP_SIZE) { - sum_h2 = __hfma2(state_col[row], vec_h2[row], sum_h2); - } - - float sum = __half2float(sum_h2.x) + __half2float(sum_h2.y); - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - - if (lane_id == 0) { - float v_attn = sVNew[col] * shared_attn_score; - sOut[col] = sum + v_attn; - } - } - __syncthreads(); - - // 8. Update state: state_new = decay * state + outer(k, v_new) - // state[row, col] = decay * state[row, col] + k[row] * v_new[col] - half decay_h = __float2half(fminf(fmaxf(decay, 0.0f), 10.0f)); - - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - int col = i / HEAD_DIM; - int row = i % HEAD_DIM; - - half state_val = state_smem[row + col * HEAD_DIM]; - half k_val = sK_fp16[row]; - half v_new_h = __float2half(sVNew[col]); - - half new_val = __hfma(decay_h, state_val, __hmul(k_val, v_new_h)); - - float new_val_f = __half2float(new_val); - new_val_f = fminf(fmaxf(new_val_f, -1e4f), 1e4f); - state_smem[row + col * HEAD_DIM] = __float2half(new_val_f); - } - __syncthreads(); - - // 9. 
Write output - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - out_base[t * out_token_stride + i] = sOut[i]; - } - __syncthreads(); - } - - // Write final state DIRECTLY (no transpose needed - same layout) - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - state_dst[i] = __half2float(state_smem[i]); - } -} - -#endif // !defined(GGML_USE_HIP) - -// Blackwell kernel (SM 12.0+): Full 64KB state in shared memory -#if !defined(GGML_USE_HIP) - -template -__global__ __launch_bounds__(256, 1) // 256 threads, 1 block per SM for max shared mem -void delta_net_blackwell_f32( - const float * __restrict__ q, - const float * __restrict__ k, - const float * __restrict__ v, - const float * __restrict__ g, - const float * __restrict__ beta_in, - const float * __restrict__ state_in, - float * __restrict__ dst, - const int64_t n_tokens, - const int64_t n_heads, - const int64_t n_seqs, - const int64_t output_offset, - const float eps) -{ - static_assert(HEAD_DIM == 128, "Blackwell kernel optimized for HEAD_DIM=128"); - - // One block per (batch, head) - NO column splitting! - const int batch_idx = blockIdx.x / n_heads; - const int head_idx = blockIdx.x % n_heads; - const int tid = threadIdx.x; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - constexpr int NUM_WARPS = 8; // 256 / 32 - - // Strides (column-major) - const int64_t qkv_stride_token = HEAD_DIM; - const int64_t qkv_stride_head = HEAD_DIM * n_tokens; - const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; - const int64_t g_stride_head = n_tokens; - const int64_t g_stride_batch = n_tokens * n_heads; - const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; - const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; - - // Pointers - const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; - float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; - const int64_t out_token_stride = HEAD_DIM * n_heads; - float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; - - // Shared memory: 64KB state + 4.5KB vectors + scratch - extern __shared__ char smem_raw[]; - float * state_smem = (float *)smem_raw; - float * sQ = (float *)(smem_raw + HEAD_DIM * HEAD_DIM * sizeof(float)); - float * sK = sQ + HEAD_DIM; - float * sV = sK + HEAD_DIM; - float * sKBeta = sV + HEAD_DIM; - float * sVBeta = sKBeta + HEAD_DIM; - float * sKCumdecay = sVBeta + HEAD_DIM; - float * sVPrime = sKCumdecay + HEAD_DIM; - float * sVNew = sVPrime + HEAD_DIM; - float * sOut = sVNew + HEAD_DIM; - - float * warp_scratch = sOut + HEAD_DIM; - __shared__ float shared_decay, shared_beta, shared_attn_score, shared_q_norm, shared_k_norm; - const float scale = rsqrtf((float)HEAD_DIM); - - // Load state (transposed for coalesced access) - #pragma unroll 8 - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - int col = i / HEAD_DIM, row = i % HEAD_DIM; - state_smem[row + col * HEAD_DIM] = state_src[col + row * HEAD_DIM]; - } - __syncthreads(); - - for (int64_t t = 0; t < n_tokens; t++) 
{ - // Load Q, K, V and compute norms - float q_sq_local = 0.0f, k_sq_local = 0.0f; - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - float q_val = q_ptr[t * qkv_stride_token + i]; - float k_val = k_ptr[t * qkv_stride_token + i]; - sQ[i] = q_val; - sK[i] = k_val; - sV[i] = v_ptr[t * qkv_stride_token + i]; - q_sq_local += q_val * q_val; - k_sq_local += k_val * k_val; - } - - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); - k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); - } - - if (lane_id == 0) { - warp_scratch[warp_id * 2] = q_sq_local; - warp_scratch[warp_id * 2 + 1] = k_sq_local; - } - __syncthreads(); - - if (tid == 0) { - float total_q = 0.0f, total_k = 0.0f; - #pragma unroll - for (int w = 0; w < NUM_WARPS; w++) { - total_q += warp_scratch[w * 2]; - total_k += warp_scratch[w * 2 + 1]; - } - shared_q_norm = rsqrtf(total_q + eps); - shared_k_norm = rsqrtf(total_k + eps); - shared_decay = expf(fminf(g_ptr[t], 50.0f)); - shared_beta = sigmoid_f(beta_ptr[t]); - } - __syncthreads(); - - float q_norm = shared_q_norm; - float k_norm = shared_k_norm; - float decay = shared_decay; - float beta_val = shared_beta; - - // Normalize and prepare vectors - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - sQ[i] = sQ[i] * q_norm * scale; - sK[i] = sK[i] * k_norm; - sKBeta[i] = sK[i]; - sVBeta[i] = sV[i] * beta_val; - sKCumdecay[i] = sK[i] * beta_val * decay; - } - __syncthreads(); - - // v_prime = state @ k_cumdecay - for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll 4 - for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { - sum += state_smem[row + col * HEAD_DIM] * sKCumdecay[row]; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) sVPrime[col] = sum; - } - __syncthreads(); - - // v_new = v_beta - v_prime - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - sVNew[i] = sVBeta[i] - sVPrime[i]; - } - __syncthreads(); - - // attn_score = dot(K, Q) - if (warp_id == 0) { - float sum = 0.0f; - #pragma unroll 4 - for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { - sum += sK[i] * sQ[i]; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) shared_attn_score = sum; - } - __syncthreads(); - - float attn_score = shared_attn_score; - - // output = (state @ q*decay) + v_new * attn_score - for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll 4 - for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { - sum += state_smem[row + col * HEAD_DIM] * sQ[row] * decay; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) sOut[col] = sum + sVNew[col] * attn_score; - } - __syncthreads(); - - // Update state: state_new = decay * state + outer(v_new, k) - float safe_decay = (isnan(decay) || isinf(decay)) ? 
1.0f : decay; - for (int col = tid; col < HEAD_DIM; col += blockDim.x) { - float v_col = sVNew[col]; - for (int row = 0; row < HEAD_DIM; row++) { - float old_state = state_smem[row + col * HEAD_DIM]; - float new_state = safe_decay * old_state + v_col * sKBeta[row]; - state_smem[row + col * HEAD_DIM] = fminf(fmaxf(new_state, -1e6f), 1e6f); - } - } - __syncthreads(); - - // Write output - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - out_base[t * out_token_stride + i] = sOut[i]; - } - __syncthreads(); - } - - // Write final state (transpose back) - #pragma unroll 8 - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - int col = i / HEAD_DIM, row = i % HEAD_DIM; - state_dst[col + row * HEAD_DIM] = state_smem[row + col * HEAD_DIM]; - } -} - -// Blackwell V2: Bank-conflict-free with padded layout (128→132) -template -__global__ __launch_bounds__(256, 1) -void delta_net_blackwell_optimized_f32( - const float * __restrict__ q, - const float * __restrict__ k, - const float * __restrict__ v, - const float * __restrict__ g, - const float * __restrict__ beta_in, - const float * __restrict__ state_in, - float * __restrict__ dst, - const int64_t n_tokens, - const int64_t n_heads, - const int64_t n_seqs, - const int64_t output_offset, - const float eps) -{ - static_assert(HEAD_DIM == 128, "Optimized kernel for HEAD_DIM=128"); - constexpr int PADDED_DIM = HEAD_DIM + 4; // Bank conflict elimination - - const int batch_idx = blockIdx.x / n_heads; - const int head_idx = blockIdx.x % n_heads; - const int tid = threadIdx.x; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - constexpr int NUM_WARPS = 8; - - const int64_t qkv_stride_token = HEAD_DIM; - const int64_t qkv_stride_head = HEAD_DIM * n_tokens; - const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; - const int64_t g_stride_head = n_tokens; - const int64_t g_stride_batch = n_tokens * n_heads; - const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; - const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; - - const float * q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; - float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; - const int64_t out_token_stride = HEAD_DIM * n_heads; - float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; - - // Shared memory: 67.5KB padded state + 4.5KB vectors - extern __shared__ char smem_raw[]; - float * state_smem = (float *)smem_raw; - float * sQ = (float *)(smem_raw + HEAD_DIM * PADDED_DIM * sizeof(float)); - float * sK = sQ + HEAD_DIM; - float * sV = sK + HEAD_DIM; - float * sKBeta = sV + HEAD_DIM; - float * sVBeta = sKBeta + HEAD_DIM; - float * sKCumdecay = sVBeta + HEAD_DIM; - float * sVPrime = sKCumdecay + HEAD_DIM; - float * sVNew = sVPrime + HEAD_DIM; - float * sOut = sVNew + HEAD_DIM; - float * warp_scratch = sOut + HEAD_DIM; - __shared__ float shared_decay, shared_beta, shared_attn_score, shared_q_norm, shared_k_norm; - const float scale = rsqrtf((float)HEAD_DIM); - - // Load state 
with padding - #pragma unroll 8 - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - int col = i / HEAD_DIM, row = i % HEAD_DIM; - state_smem[row + col * PADDED_DIM] = state_src[row + col * HEAD_DIM]; - } - __syncthreads(); - - for (int64_t t = 0; t < n_tokens; t++) { - // Load Q, K, V (vectorized) - float q_sq_local = 0.0f, k_sq_local = 0.0f; - const float4 * q_ptr_v = (const float4 *)(q_ptr + t * qkv_stride_token); - const float4 * k_ptr_v = (const float4 *)(k_ptr + t * qkv_stride_token); - const float4 * v_ptr_v = (const float4 *)(v_ptr + t * qkv_stride_token); - - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM / 4; i += blockDim.x) { - float4 q_val = q_ptr_v[i]; - float4 k_val = k_ptr_v[i]; - float4 v_val = v_ptr_v[i]; - int base = i * 4; - sQ[base] = q_val.x; sQ[base+1] = q_val.y; sQ[base+2] = q_val.z; sQ[base+3] = q_val.w; - sK[base] = k_val.x; sK[base+1] = k_val.y; sK[base+2] = k_val.z; sK[base+3] = k_val.w; - sV[base] = v_val.x; sV[base+1] = v_val.y; sV[base+2] = v_val.z; sV[base+3] = v_val.w; - q_sq_local += q_val.x*q_val.x + q_val.y*q_val.y + q_val.z*q_val.z + q_val.w*q_val.w; - k_sq_local += k_val.x*k_val.x + k_val.y*k_val.y + k_val.z*k_val.z + k_val.w*k_val.w; - } - - // Warp reduction for norms - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); - k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); - } - - // Cross-warp reduction using shared memory - if (lane_id == 0) { - warp_scratch[warp_id * 2] = q_sq_local; - warp_scratch[warp_id * 2 + 1] = k_sq_local; - } - __syncthreads(); - - if (tid == 0) { - float total_q = 0.0f, total_k = 0.0f; - #pragma unroll - for (int w = 0; w < NUM_WARPS; w++) { - total_q += warp_scratch[w * 2]; - total_k += warp_scratch[w * 2 + 1]; - } - shared_q_norm = rsqrtf(total_q + eps); - shared_k_norm = rsqrtf(total_k + eps); - shared_decay = expf(fminf(g_ptr[t], 50.0f)); - shared_beta = sigmoid_f(beta_ptr[t]); - } - __syncthreads(); - - float q_norm = shared_q_norm, k_norm = shared_k_norm; - float decay = shared_decay, beta_val = shared_beta; - - // Normalize vectors - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - sQ[i] = sQ[i] * q_norm * scale; - sK[i] = sK[i] * k_norm; - sKBeta[i] = sK[i]; - sVBeta[i] = sV[i] * beta_val; - sKCumdecay[i] = sK[i] * beta_val * decay; - } - __syncthreads(); - - // v_prime = state @ k_cumdecay - for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll 4 - for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { - sum += state_smem[row_out + col * PADDED_DIM] * sKCumdecay[col]; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) sVPrime[row_out] = sum; - } - __syncthreads(); - - // v_new = v_beta - v_prime - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - sVNew[i] = sVBeta[i] - sVPrime[i]; - } - __syncthreads(); - - // attn_score = dot(K, Q) - if (warp_id == 0) { - float sum = 0.0f; - #pragma unroll 4 - for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { - sum += sK[i] * sQ[i]; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) shared_attn_score = sum; - } - __syncthreads(); - - float attn_score = shared_attn_score; - - // output = (state @ q*decay) + v_new * attn_score - for (int row_out = 
warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll 4 - for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) { - sum += state_smem[row_out + col * PADDED_DIM] * sQ[col] * decay; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) sOut[row_out] = sum + sVNew[row_out] * attn_score; - } - __syncthreads(); - - // State update - #pragma unroll 4 - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - int col = i / HEAD_DIM, row = i % HEAD_DIM; - float old_state = state_smem[row + col * PADDED_DIM]; - float new_state = decay * old_state + sKBeta[row] * sVNew[col]; - state_smem[row + col * PADDED_DIM] = fminf(fmaxf(new_state, -1e6f), 1e6f); - } - __syncthreads(); - - // Write output (vectorized) - float4 * out_ptr_v = (float4 *)(out_base + t * out_token_stride); - #pragma unroll 2 - for (int i = tid; i < HEAD_DIM / 4; i += blockDim.x) { - int base = i * 4; - float4 out_val = {sOut[base], sOut[base+1], sOut[base+2], sOut[base+3]}; - out_ptr_v[i] = out_val; - } - __syncthreads(); - } - - // Write final state (remove padding) - #pragma unroll 8 - for (int i = tid; i < HEAD_DIM * HEAD_DIM; i += blockDim.x) { - int col = i / HEAD_DIM, row = i % HEAD_DIM; - state_dst[row + col * HEAD_DIM] = state_smem[row + col * PADDED_DIM]; - } -} - -#endif // !defined(GGML_USE_HIP) - -// Multi-block column-parallel kernel (pre-Blackwell fallback) -// Each block handles COLS_PER_BLOCK columns of the 128x128 state -// With COLS_PER_BLOCK=16: 128/16 = 8 blocks per head, 16 heads = 128 blocks total -// State tile per block: 128 rows × 16 cols = 2048 floats = 8KB (fits in shared memory!) -template -__global__ void delta_net_multiblock_f32( - const float * __restrict__ q, // [HEAD_DIM, n_tokens, n_heads, n_seqs] - const float * __restrict__ k, // [HEAD_DIM, n_tokens, n_heads, n_seqs] - const float * __restrict__ v, // [HEAD_DIM, n_tokens, n_heads, n_seqs] - const float * __restrict__ g, // [n_tokens, 1, n_heads, n_seqs] - const float * __restrict__ beta_in, // [1, n_tokens, n_heads, n_seqs] - const float * __restrict__ state_in, // [HEAD_DIM, HEAD_DIM*n_heads, 1, n_seqs] - float * __restrict__ dst, // output + new_state concatenated - const int64_t n_tokens, - const int64_t n_heads, - const int64_t n_seqs, - const int64_t output_offset, - const float eps) -{ - static_assert(HEAD_DIM % COLS_PER_BLOCK == 0, "HEAD_DIM must be divisible by COLS_PER_BLOCK"); - constexpr int NUM_COL_GROUPS = HEAD_DIM / COLS_PER_BLOCK; - - // Decode block index: (batch_idx, head_idx, col_group) - const int blocks_per_seq = n_heads * NUM_COL_GROUPS; - const int batch_idx = blockIdx.x / blocks_per_seq; - const int remaining = blockIdx.x % blocks_per_seq; - const int head_idx = remaining / NUM_COL_GROUPS; - const int col_group = remaining % NUM_COL_GROUPS; - const int col_start = col_group * COLS_PER_BLOCK; - - const int tid = threadIdx.x; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - constexpr int NUM_WARPS = 8; - - // Strides (column-major) - const int64_t qkv_stride_token = HEAD_DIM; - const int64_t qkv_stride_head = HEAD_DIM * n_tokens; - const int64_t qkv_stride_batch = HEAD_DIM * n_tokens * n_heads; - const int64_t g_stride_head = n_tokens; - const int64_t g_stride_batch = n_tokens * n_heads; - const int64_t state_head_offset = head_idx * HEAD_DIM * HEAD_DIM; - const int64_t state_batch_stride = HEAD_DIM * HEAD_DIM * n_heads; - - // Pointers - const float 
* q_ptr = q + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * k_ptr = k + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * v_ptr = v + batch_idx * qkv_stride_batch + head_idx * qkv_stride_head; - const float * g_ptr = g + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * beta_ptr = beta_in + batch_idx * g_stride_batch + head_idx * g_stride_head; - const float * state_src = state_in + batch_idx * state_batch_stride + state_head_offset; - - float * out_base = dst + batch_idx * (HEAD_DIM * n_heads * n_tokens) + head_idx * HEAD_DIM; - const int64_t out_token_stride = HEAD_DIM * n_heads; - float * state_dst = dst + output_offset + batch_idx * state_batch_stride + state_head_offset; - - // Shared memory layout: - // - State tile: HEAD_DIM × COLS_PER_BLOCK = 128 × 16 = 2048 floats = 8KB - // - Full vectors: K, KCumdecay, Q (need all HEAD_DIM elements) = 3 × 128 = 1.5KB - // - Local vectors: V, VBeta, VPrime, VNew, Out (only COLS_PER_BLOCK) = 5 × 16 = 320 bytes - // - Norms: 2 floats - // Total: ~10KB (excellent for occupancy!) - extern __shared__ float smem[]; - - // State tile in shared memory: state_tile[row + local_col * HEAD_DIM] - // local_col ∈ [0, COLS_PER_BLOCK), global_col = col_start + local_col - float * state_tile = smem; // HEAD_DIM * COLS_PER_BLOCK - - // Full vectors (need all HEAD_DIM for matrix-vector and dot products) - float * sK = state_tile + HEAD_DIM * COLS_PER_BLOCK; // HEAD_DIM - float * sKCumdecay = sK + HEAD_DIM; // HEAD_DIM - float * sQ = sKCumdecay + HEAD_DIM; // HEAD_DIM - - // Local vectors (only need COLS_PER_BLOCK elements) - float * sV = sQ + HEAD_DIM; // COLS_PER_BLOCK - float * sVBeta = sV + COLS_PER_BLOCK; // COLS_PER_BLOCK - float * sVPrime = sVBeta + COLS_PER_BLOCK; // COLS_PER_BLOCK - float * sVNew = sVPrime + COLS_PER_BLOCK; // COLS_PER_BLOCK - float * sOut = sVNew + COLS_PER_BLOCK; // COLS_PER_BLOCK - float * sNorm = sOut + COLS_PER_BLOCK; // 2 - - __shared__ float shared_decay, shared_beta, shared_attn_score; - - const float scale = rsqrtf((float)HEAD_DIM); - - // Load initial state tile from global to shared - // state_tile[row + local_col * HEAD_DIM] = state[row, col_start + local_col] - for (int i = tid; i < HEAD_DIM * COLS_PER_BLOCK; i += blockDim.x) { - int row = i % HEAD_DIM; - int local_col = i / HEAD_DIM; - int global_col = col_start + local_col; - state_tile[row + local_col * HEAD_DIM] = state_src[row + global_col * HEAD_DIM]; - } - __syncthreads(); - - // Process each token - for (int64_t t = 0; t < n_tokens; t++) { - // Reset norms - if (tid < 2) { - sNorm[tid] = 0.0f; - } - __syncthreads(); - - // 1. 
Load full K, Q (all HEAD_DIM elements - needed for matrix-vector and attn_score) - float q_sq_local = 0.0f, k_sq_local = 0.0f; - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - float q_val = q_ptr[t * qkv_stride_token + i]; - float k_val = k_ptr[t * qkv_stride_token + i]; - sQ[i] = q_val; - sK[i] = k_val; - q_sq_local += q_val * q_val; - k_sq_local += k_val * k_val; - } - - // Load V for our columns only - for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { - sV[i] = v_ptr[t * qkv_stride_token + col_start + i]; - } - - // Warp reduction for norms - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - q_sq_local += __shfl_xor_sync(0xffffffff, q_sq_local, offset); - k_sq_local += __shfl_xor_sync(0xffffffff, k_sq_local, offset); - } - if (lane_id == 0) { - atomicAdd(&sNorm[0], q_sq_local); - atomicAdd(&sNorm[1], k_sq_local); - } - __syncthreads(); - - float q_norm = rsqrtf(sNorm[0] + eps); - float k_norm = rsqrtf(sNorm[1] + eps); - - // 2. Load g, beta and normalize vectors - if (tid == 0) { - shared_decay = expf(fminf(g_ptr[t], 50.0f)); // Clamp g to prevent overflow - shared_beta = sigmoid_f(beta_ptr[t]); - } - __syncthreads(); - - float decay = shared_decay; - float beta_val = shared_beta; - - // Normalize and compute KCumdecay - for (int i = tid; i < HEAD_DIM; i += blockDim.x) { - sQ[i] = sQ[i] * q_norm * scale; - sK[i] = sK[i] * k_norm; - sKCumdecay[i] = sK[i] * beta_val * decay; - } - - // Compute VBeta for our columns - for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { - sVBeta[i] = sV[i] * beta_val; - } - __syncthreads(); - - // 3. Compute v_prime for our columns: v_prime[local_col] = sum_row(state_tile[row, local_col] * k_cumdecay[row]) - // Each warp handles one local column - for (int local_col = warp_id; local_col < COLS_PER_BLOCK; local_col += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll 4 - for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { - sum += state_tile[row + local_col * HEAD_DIM] * sKCumdecay[row]; - } - // Warp reduction - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) { - sVPrime[local_col] = sum; - } - } - __syncthreads(); - - // 4. Compute v_new for our columns - for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { - sVNew[i] = sVBeta[i] - sVPrime[i]; - } - __syncthreads(); - - // 5. Compute attn_score = dot(k, q) - all blocks compute this redundantly - if (warp_id == 0) { - float sum = 0.0f; - #pragma unroll 4 - for (int i = lane_id; i < HEAD_DIM; i += WARP_SIZE) { - sum += sK[i] * sQ[i]; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) { - shared_attn_score = sum; - } - } - __syncthreads(); - - // 6. Compute output for our columns: out[local_col] = attn_inter + v_attn - // attn_inter[local_col] = sum_row(state_tile[row, local_col] * q_scaled[row]) - for (int local_col = warp_id; local_col < COLS_PER_BLOCK; local_col += NUM_WARPS) { - float sum = 0.0f; - #pragma unroll 4 - for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) { - sum += state_tile[row + local_col * HEAD_DIM] * sQ[row] * decay; - } - #pragma unroll - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) { - sum += __shfl_xor_sync(0xffffffff, sum, offset); - } - if (lane_id == 0) { - float v_attn = sVNew[local_col] * shared_attn_score; - sOut[local_col] = sum + v_attn; - } - } - __syncthreads(); - - // 7. 
Update state tile: state_new[row, local_col] = decay * state[row, local_col] + v_new[row] * k[local_col] - // Fixed: outer product orientation matches decomposed: state[v_idx, k_idx] += v_new[v_idx] * k[k_idx] - float safe_decay = fminf(fmaxf(decay, 0.0f), 10.0f); - for (int i = tid; i < HEAD_DIM * COLS_PER_BLOCK; i += blockDim.x) { - int row = i % HEAD_DIM; - int local_col = i / HEAD_DIM; - - float state_val = state_tile[row + local_col * HEAD_DIM]; - // Fix: v_new[row=v_idx] * k[local_col=k_idx] to match decomposed - float new_val = safe_decay * state_val + sVNew[row] * sK[local_col]; - new_val = fminf(fmaxf(new_val, -1e6f), 1e6f); - state_tile[row + local_col * HEAD_DIM] = new_val; - } - __syncthreads(); - - // 8. Write output for our columns - for (int i = tid; i < COLS_PER_BLOCK; i += blockDim.x) { - int global_col = col_start + i; - out_base[t * out_token_stride + global_col] = sOut[i]; - } - __syncthreads(); - } - - // Write final state tile back to global - for (int i = tid; i < HEAD_DIM * COLS_PER_BLOCK; i += blockDim.x) { - int row = i % HEAD_DIM; - int local_col = i / HEAD_DIM; - int global_col = col_start + local_col; - state_dst[row + global_col * HEAD_DIM] = state_tile[row + local_col * HEAD_DIM]; - } -} - -enum delta_net_opt_mode : int { - DELTA_NET_OPT_DEFAULT = 0, // keep current dispatch - DELTA_NET_OPT_FP16 = 1, // pre-Blackwell: fp16 recurrent kernel (head_dim=128) - DELTA_NET_OPT_MULTIBLOCK = 2, // pre-Blackwell: multiblock kernel (head_dim=128) - DELTA_NET_OPT_BW_OPT = 3, // Blackwell: padded/bank-conflict-reduced kernel - DELTA_NET_OPT_AUTO = 4, // arch-aware: multiblock (pre-BW), bw-opt (BW) -}; - -static int delta_net_get_opt_mode() { - static const int mode = []() -> int { - const char * env = std::getenv("GGML_CUDA_DELTA_NET_OPT"); - if (env == nullptr || env[0] == '\0') { - return DELTA_NET_OPT_DEFAULT; - } - - if (!strcmp(env, "auto") || !strcmp(env, "AUTO")) { - return DELTA_NET_OPT_AUTO; - } - if (!strcmp(env, "fp16")) { - return DELTA_NET_OPT_FP16; - } - if (!strcmp(env, "multiblock")) { - return DELTA_NET_OPT_MULTIBLOCK; - } - if (!strcmp(env, "blackwell-opt")) { - return DELTA_NET_OPT_BW_OPT; - } - - const int parsed = atoi(env); - if (parsed >= DELTA_NET_OPT_DEFAULT && parsed <= DELTA_NET_OPT_AUTO) { - return parsed; - } - - return DELTA_NET_OPT_DEFAULT; - }(); - - return mode; -} - -// Dispatch function -// device_id and cc (compute capability) are passed from caller to avoid CUDA runtime API calls static void delta_net_f32_cuda( const float * q, const float * k, @@ -1426,220 +412,34 @@ static void delta_net_f32_cuda( const float eps, const int device_id, const int cc, // compute capability (e.g., 890 for SM 8.9, 1200 for SM 12.0) - cudaStream_t stream) -{ + cudaStream_t stream) { GGML_UNUSED(device_id); + GGML_UNUSED(cc); const int64_t output_offset = head_dim * n_tokens * n_heads * n_seqs; // One block per (batch, head) pair const int num_blocks = n_seqs * n_heads; - const int threads_per_block = 256; + constexpr int threads_per_block = 256; // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew) // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score const size_t smem_size = (9 * head_dim + 6) * sizeof(float); - const int opt_mode = delta_net_get_opt_mode(); // Use templated kernel for common head dimensions, generic for others if (head_dim == 64) { - delta_net_recurrent_f32<64><<>>( - q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); + delta_net_recurrent_f32<64, 
threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
+            q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
     } else if (head_dim == 128) {
-#if !defined(GGML_USE_HIP)
-        // Check for Blackwell (SM 12.0+) which has 228KB shared memory
-        // cc is in format MAJOR*100 + MINOR*10 (e.g., 890 for 8.9, 1200 for 12.0)
-        const int sm_major = cc / 100;
-
-        constexpr size_t blackwell_state_bytes = 128 * 128 * sizeof(float); // 64 KB
-        constexpr size_t blackwell_vector_bytes = 9 * 128 * sizeof(float); // 4.5 KB
-        constexpr size_t blackwell_warp_scratch_bytes = 16 * sizeof(float); // 64 B
-        constexpr size_t blackwell_smem_size =
-            blackwell_state_bytes + blackwell_vector_bytes + blackwell_warp_scratch_bytes;
-        static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");
-
-        constexpr size_t blackwell_opt_state_bytes = 128 * 132 * sizeof(float); // padded 128x132
-        constexpr size_t blackwell_opt_vector_bytes = 9 * 128 * sizeof(float);
-        constexpr size_t blackwell_opt_warp_scratch_bytes = 16 * sizeof(float);
-        constexpr size_t blackwell_opt_smem_size =
-            blackwell_opt_state_bytes + blackwell_opt_vector_bytes + blackwell_opt_warp_scratch_bytes;
-        static_assert(blackwell_opt_smem_size == 72256, "Optimized shared memory size mismatch");
-
-        constexpr int multiblock_cols = 16;
-        constexpr int multiblock_groups = 128 / multiblock_cols;
-        constexpr size_t multiblock_smem_floats =
-            128 * multiblock_cols + 3 * 128 + 5 * multiblock_cols + 2;
-        constexpr size_t multiblock_smem_size = multiblock_smem_floats * sizeof(float);
-        static_assert(multiblock_smem_size == 10056, "Multiblock shared memory size mismatch");
-
-        constexpr size_t fp16_state_bytes = 128 * 128 * sizeof(half);
-        constexpr size_t fp16_half_vec_bytes = 3 * 128 * sizeof(half);
-        constexpr size_t fp16_float_vec_bytes = 6 * 128 * sizeof(float);
-        constexpr size_t fp16_scalar_bytes = 2 * sizeof(float);
-        constexpr size_t fp16_smem_size =
-            fp16_state_bytes + fp16_half_vec_bytes + fp16_float_vec_bytes + fp16_scalar_bytes;
-        static_assert(fp16_smem_size == 36616, "FP16 shared memory size mismatch");
-
-        // Keep "auto" conservative on Blackwell (baseline kernel remains default there).
-        // Explicit modes can still force a different kernel for experiments.
- const bool use_bw_opt = - sm_major >= 12 && opt_mode == DELTA_NET_OPT_BW_OPT; - const bool use_multiblock = - opt_mode == DELTA_NET_OPT_MULTIBLOCK || - (sm_major < 12 && opt_mode == DELTA_NET_OPT_AUTO); - const bool use_fp16 = opt_mode == DELTA_NET_OPT_FP16; - - if (use_bw_opt) { - const int blackwell_num_blocks = n_seqs * n_heads; - const int blackwell_threads = 256; - - CUDA_CHECK(cudaFuncSetAttribute( - delta_net_blackwell_optimized_f32<128>, - cudaFuncAttributeMaxDynamicSharedMemorySize, - blackwell_opt_smem_size)); - - delta_net_blackwell_optimized_f32<128><<>>( - q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); - } else if (sm_major >= 12) { - // Blackwell path: single block per head with FULL state in shared memory - const int blackwell_num_blocks = n_seqs * n_heads; - const int blackwell_threads = 256; - - // A/B comparison mode (set GGML_CUDA_DELTA_NET_AB=1) - static const bool ab_mode = []() { - const char* env = std::getenv("GGML_CUDA_DELTA_NET_AB"); - if (env != nullptr) { - fprintf(stderr, "[DELTA_NET] A/B comparison mode ENABLED\n"); - return true; - } - return false; - }(); - - if (ab_mode) { - // A/B mode: run both kernels and compare outputs - const int64_t total_output_size = output_offset + head_dim * head_dim * n_heads * n_seqs; - - // Allocate temp buffer for recurrent kernel output - float * temp_dst = nullptr; - CUDA_CHECK(cudaMallocAsync(&temp_dst, total_output_size * sizeof(float), stream)); - - // Run recurrent kernel (reference) to temp buffer - delta_net_recurrent_f32<128><<>>( - q, k, v, g, beta, state_in, temp_dst, n_tokens, n_heads, n_seqs, output_offset, eps); - - // Request extended shared memory for Blackwell - CUDA_CHECK(cudaFuncSetAttribute( - delta_net_blackwell_f32<128>, - cudaFuncAttributeMaxDynamicSharedMemorySize, - blackwell_smem_size)); - - // Run Blackwell kernel to dst - delta_net_blackwell_f32<128><<>>( - q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps); - - // Sync to ensure both kernels complete - CUDA_CHECK(cudaStreamSynchronize(stream)); - - // Copy results back to host for comparison - const int64_t output_elements = head_dim * n_tokens * n_heads * n_seqs; - const int64_t state_elements = head_dim * head_dim * n_heads * n_seqs; - - std::vector ref_output(output_elements); - std::vector ref_state(state_elements); - std::vector bw_output(output_elements); - std::vector bw_state(state_elements); - - CUDA_CHECK(cudaMemcpy(ref_output.data(), temp_dst, output_elements * sizeof(float), cudaMemcpyDeviceToHost)); - CUDA_CHECK(cudaMemcpy(ref_state.data(), temp_dst + output_offset, state_elements * sizeof(float), cudaMemcpyDeviceToHost)); - CUDA_CHECK(cudaMemcpy(bw_output.data(), dst, output_elements * sizeof(float), cudaMemcpyDeviceToHost)); - CUDA_CHECK(cudaMemcpy(bw_state.data(), dst + output_offset, state_elements * sizeof(float), cudaMemcpyDeviceToHost)); - - // Compare outputs - float max_out_diff = 0.0f; - int64_t max_out_idx = 0; - for (int64_t i = 0; i < output_elements; i++) { - float diff = fabsf(ref_output[i] - bw_output[i]); - if (diff > max_out_diff) { - max_out_diff = diff; - max_out_idx = i; - } - } - - // Compare states - float max_state_diff = 0.0f; - int64_t max_state_idx = 0; - for (int64_t i = 0; i < state_elements; i++) { - float diff = fabsf(ref_state[i] - bw_state[i]); - if (diff > max_state_diff) { - max_state_diff = diff; - max_state_idx = i; - } - } - - // Report results - static int ab_call_count = 0; - ab_call_count++; - fprintf(stderr, "[DELTA_NET A/B #%d] 
n_tokens=%lld output_diff=%e (idx=%lld ref=%e bw=%e) state_diff=%e (idx=%lld ref=%e bw=%e)\n",
-                ab_call_count,
-                (long long)n_tokens,
-                max_out_diff, (long long)max_out_idx, ref_output[max_out_idx], bw_output[max_out_idx],
-                max_state_diff, (long long)max_state_idx, ref_state[max_state_idx], bw_state[max_state_idx]);
-
-            // Report first 4 output values for head 0
-            if (ab_call_count <= 10) {
-                fprintf(stderr, "  ref_out[0:3]=[%e,%e,%e,%e] bw_out[0:3]=[%e,%e,%e,%e]\n",
-                    ref_output[0], ref_output[1], ref_output[2], ref_output[3],
-                    bw_output[0], bw_output[1], bw_output[2], bw_output[3]);
-                fprintf(stderr, "  ref_state[0,1,128,129]=[%e,%e,%e,%e] bw_state=[%e,%e,%e,%e]\n",
-                    ref_state[0], ref_state[1], ref_state[128], ref_state[129],
-                    bw_state[0], bw_state[1], bw_state[128], bw_state[129]);
-            }
-
-            CUDA_CHECK(cudaFreeAsync(temp_dst, stream));
-        } else {
-            // Normal mode: just run Blackwell kernel
-            // Request extended shared memory for Blackwell
-            CUDA_CHECK(cudaFuncSetAttribute(
-                delta_net_blackwell_f32<128>,
-                cudaFuncAttributeMaxDynamicSharedMemorySize,
-                blackwell_smem_size));
-
-            delta_net_blackwell_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_smem_size, stream>>>(
-                q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
-        }
-    } else if (use_multiblock) {
-        const int multiblock_num_blocks = n_seqs * n_heads * multiblock_groups;
-        delta_net_multiblock_f32<128, multiblock_cols><<<multiblock_num_blocks, threads_per_block, multiblock_smem_size, stream>>>(
-            q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
-    } else if (use_fp16) {
-        delta_net_fp16_optimized<128><<<num_blocks, threads_per_block, fp16_smem_size, stream>>>(
-            q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
-    } else {
-        // Baseline pre-Blackwell path
-        delta_net_recurrent_f32<128><<<num_blocks, threads_per_block, smem_size, stream>>>(
-            q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
-    }
-#else
-    // HIP path: keep baseline recurrent implementation
-    delta_net_recurrent_f32<128><<<num_blocks, threads_per_block, smem_size, stream>>>(
-        q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
-#endif // !defined(GGML_USE_HIP)
+        delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
+            q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
     } else {
         delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
             q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
     }
 
-    // Check for errors (but don't sync during graph capture)
     CUDA_CHECK(cudaGetLastError());
-#ifdef GGML_CUDA_DEBUG_SYNC
-    // Only sync when not capturing CUDA graphs
-    cudaStreamCaptureStatus capture_status;
-    CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
-    if (capture_status == cudaStreamCaptureStatusNone) {
-        CUDA_CHECK(cudaDeviceSynchronize());
-    }
-#endif
 }
 
 void ggml_cuda_op_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/src/llama-delta-net.cpp b/src/llama-delta-net.cpp
index a0949d5a..891a949a 100644
--- a/src/llama-delta-net.cpp
+++ b/src/llama-delta-net.cpp
@@ -401,11 +401,18 @@ std::pair delta_net::build_fused_delta_net(ggml_co
     cb(g,    "g_in", il);
     cb(state,"state_in", il);
 
-    q    = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
-    k    = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
-    v    = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g    = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-    beta = ggml_cont_4d(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3), 1, n_tokens, H_k, n_seqs);
+    q    = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k    = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v    = ggml_permute(ctx0, 
v, 0, 2, 1, 3); + g = ggml_permute(ctx0, g, 2, 0, 3, 1); + beta = ggml_permute(ctx0, beta, 2, 0, 1, 3); + if (n_seqs > 1) { + q = ggml_cont_4d(ctx0, q, S_k, n_tokens, H_k, n_seqs); + k = ggml_cont_4d(ctx0, k, S_k, n_tokens, H_k, n_seqs); + v = ggml_cont_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, g, n_tokens, 1, H_k, n_seqs); + beta = ggml_cont_4d(ctx0, beta, 1, n_tokens, H_k, n_seqs); + } ggml_tensor * state_flat = ggml_reshape_4d(ctx0, state, S_v, S_v * H_v, 1, n_seqs); if (!ggml_is_contiguous(state_flat)) {
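
For reference, the per-token update that the remaining delta_net_recurrent_f32 path performs can be written as a small scalar sketch (one sequence, one head). This is not part of the patch: the function name and the std::vector scratch buffers are illustrative only, and the NaN/overflow clamping that the kernel applies to the state update is omitted.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // One delta-net recurrence step. D is the head dimension; S is the D x D
    // running state stored column-major as S[r + c*D], matching the kernel.
    static void delta_net_step_ref(int D, float * S,
                                   const float * q_in, const float * k_in, const float * v,
                                   float g, float beta, float eps, float * out) {
        std::vector<float> q(q_in, q_in + D), k(k_in, k_in + D), v_new(D);

        // L2-normalize q and k; q additionally carries the 1/sqrt(D) attention scale.
        float q_sq = 0.0f, k_sq = 0.0f;
        for (int i = 0; i < D; ++i) { q_sq += q[i]*q[i]; k_sq += k[i]*k[i]; }
        const float scale  = 1.0f / std::sqrt((float) D);
        const float q_norm = 1.0f / std::sqrt(q_sq + eps);
        const float k_norm = 1.0f / std::sqrt(k_sq + eps);
        for (int i = 0; i < D; ++i) { q[i] *= q_norm * scale; k[i] *= k_norm; }

        const float beta_val = 1.0f / (1.0f + std::exp(-beta)); // sigmoid(beta)
        const float decay    = std::exp(std::min(g, 50.0f));    // clamped gate

        // v_prime = S * (k * beta * decay);  v_new = v*beta - v_prime
        for (int r = 0; r < D; ++r) {
            float vp = 0.0f;
            for (int c = 0; c < D; ++c) vp += S[r + c*D] * k[c] * beta_val * decay;
            v_new[r] = v[r] * beta_val - vp;
        }

        // attn = dot(k, q) over the normalized vectors
        float attn = 0.0f;
        for (int i = 0; i < D; ++i) attn += k[i] * q[i];

        // out = decay * (S * q) + v_new * attn
        for (int r = 0; r < D; ++r) {
            float acc = 0.0f;
            for (int c = 0; c < D; ++c) acc += S[r + c*D] * q[c] * decay;
            out[r] = acc + v_new[r] * attn;
        }

        // S <- decay * S + outer(k, v_new)   (the kernel additionally clamps each entry)
        for (int c = 0; c < D; ++c) {
            for (int r = 0; r < D; ++r) {
                S[r + c*D] = decay * S[r + c*D] + k[r] * v_new[c];
            }
        }
    }

The CUDA kernel runs this loop over tokens with one block per (sequence, head) pair: the two D x D matrix-vector products are split across warps, the scalar reductions go through the new reduce_sum helper, and the running state lives directly in the output buffer (state_dst), which is why the initial state is copied there once before the token loop.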