More CUDA fused delta net optimizations

This commit is contained in:
Kawrakow
2026-02-24 10:00:03 +00:00
parent fecdcd5aa1
commit 7af6892dcf

View File

@@ -42,11 +42,11 @@ __global__ void delta_net_recurrent_f32(
const int64_t output_offset, // offset where state starts in output
const float eps) {
const int batch_idx = blockIdx.x / n_heads;
const int head_idx = blockIdx.x % n_heads;
const int head_idx = blockIdx.x % n_heads;
const int tid = threadIdx.x;
const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads
const int lane_id = tid % WARP_SIZE; // 0-31
constexpr int NUM_WARPS = block_size/WARP_SIZE; // 256 / 32
constexpr int NUM_WARPS = block_size/WARP_SIZE;
// Strides for input tensors (column-major)
// Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
@@ -88,8 +88,7 @@ __global__ void delta_net_recurrent_f32(
float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
float * sVPrime = sKCumdecay + HEAD_DIM; // HEAD_DIM (state @ k_cumdecay)
float * sVNew = sVPrime + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
const float scale = rsqrtf((float)HEAD_DIM);
@@ -123,52 +122,41 @@ __global__ void delta_net_recurrent_f32(
float beta_val = sigmoid_f(beta_ptr[t]);
float decay = expf(fminf(g_ptr[t], 50.0f));
float sum = 0;
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
sQ[i] = sQ[i] * q_norm * scale;
sK[i] = sK[i] * k_norm;
sKBeta[i] = sK[i];
sVBeta[i] = sV[i] * beta_val;
sKCumdecay[i] = sK[i] * beta_val * decay;
}
__syncthreads();
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
float sum = 0.0f;
#pragma unroll
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
sum += state_dst[row_out + col * HEAD_DIM] * sKCumdecay[col];
}
sum = warp_reduce_sum(sum);
if (lane_id == 0) {
sVPrime[row_out] = sum;
}
}
__syncthreads();
float sum = 0;
for (int i = tid; i < HEAD_DIM; i += block_size) {
sVNew[i] = sVBeta[i] - sVPrime[i];
sum += sK[i] * sQ[i];
}
float attn_score = reduce_sum<block_size>(sum, sum_helper);
//__syncthreads();
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
float sum = 0.0f;
float sum1 = 0.0f;
float sum2 = 0.0f;
#pragma unroll
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
float state_val = state_dst[row_out + col * HEAD_DIM];
sum += sQ[col] * decay * state_val;
float sval = state_dst[row_out + col * HEAD_DIM];
sum1 += sval * sKCumdecay[col];
sum2 += sval * sQ[col];
}
sum = warp_reduce_sum(sum);
sum1 = warp_reduce_sum(sum1);
sum2 = warp_reduce_sum(sum2);
if (lane_id == 0) {
sVNew[row_out] = sVBeta[row_out] - sum1;
float v_attn = sVNew[row_out] * attn_score;
sOut[row_out] = sum + v_attn;
//sOut[row_out] = sum2 * decay + v_attn;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
}
__syncthreads();
for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += block_size) {
for (int row = 0; row < HEAD_DIM; row++) {
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
#pragma unroll
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
float state_val = state_dst[row + out_dim * HEAD_DIM];
float safe_decay = decay;
if (isnan(safe_decay) || isinf(safe_decay)) {
@@ -179,12 +167,10 @@ __global__ void delta_net_recurrent_f32(
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
}
}
__syncthreads();
for (int i = tid; i < HEAD_DIM; i += block_size) {
out_base[t * out_token_stride + i] = sOut[i];
if (t < n_tokens - 1) {
__syncthreads();
}
__syncthreads();
}
}
@@ -420,7 +406,7 @@ static void delta_net_f32_cuda(
// One block per (batch, head) pair
const int num_blocks = n_seqs * n_heads;
constexpr int threads_per_block = 256;
constexpr int threads_per_block = 512; //256;
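// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
// NOTE(review): the kernel no longer carves out a VPrime buffer (sVNew now directly follows sKCumdecay), so one head_dim slot of this budget appears unused - confirm against the smem_size computation.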
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
@@ -431,6 +417,7 @@ static void delta_net_f32_cuda(
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
} else if (head_dim == 128) {
GGML_ASSERT(num_blocks % 8 == 0);
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
} else {