More tweaks

This commit is contained in:
Kawrakow
2026-02-25 13:11:47 +00:00
parent 8af3755f32
commit a8ef7e20e7

View File

@@ -95,28 +95,27 @@ __global__ void delta_net_recurrent_f32(
state_dst[i] = state_src[i];
}
__shared__ float all_sum[2*HEAD_DIM*NUM_WARPS];
constexpr int HEAD_DIM_S = HEAD_DIM + 1;
__shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
auto all_sum1 = all_sum;
auto all_sum2 = all_sum1 + HEAD_DIM*NUM_WARPS;
auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
// Process each token sequentially
for (int64_t t = 0; t < n_tokens; t++) {
float sum_kq = 0.0f;
for (int i = tid; i < HEAD_DIM; i += block_size) {
sQ[i] = q_ptr[t * qkv_stride_token + i];
sQ[i] = q_ptr[t * qkv_stride_token + i] * scale;
sK[i] = k_ptr[t * qkv_stride_token + i];
sV[i] = v_ptr[t * qkv_stride_token + i];
sum_kq += sK[i] * sQ[i];
}
sum_kq = reduce_sum<block_size>(sum_kq, sum_helper);
float attn_score = reduce_sum<block_size>(sum_kq, sum_helper);
float beta_val = sigmoid_f(beta_ptr[t]);
float decay = expf(fminf(g_ptr[t], 50.0f));
float attn_score = sum_kq * scale;
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
float sum1 = 0.0f;
float sum2 = 0.0f;
@@ -126,45 +125,25 @@ __global__ void delta_net_recurrent_f32(
sum1 += sval * sK[col];
sum2 += sval * sQ[col];
}
all_sum1[warp_id*HEAD_DIM + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM + row_out] = sum2;
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
}
__syncthreads();
for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
float sum1 = all_sum1[row_out];
float sum2 = all_sum2[row_out];
#pragma unroll
for (int i = 1; i < NUM_WARPS; ++i) {
sum1 += all_sum1[row_out + i*HEAD_DIM];
sum2 += all_sum2[row_out + i*HEAD_DIM];
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
}
sum1 *= beta_val * decay;
sum2 *= scale * decay;
sVNew[row_out] = sV[row_out] * beta_val - sum1;
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 + v_attn;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
__syncthreads();
//for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
// float sum1 = 0.0f;
// float sum2 = 0.0f;
// #pragma unroll
// for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
// float sval = state_dst[row_out + col * HEAD_DIM];
// sum1 += sval * sK[col];
// sum2 += sval * sQ[col];
// }
// sum1 = warp_reduce_sum(sum1) * beta_val * decay;
// sum2 = warp_reduce_sum(sum2) * scale * decay;
// if (lane_id == 0) {
// sVNew[row_out] = sV[row_out] * beta_val - sum1;
// float v_attn = sVNew[row_out] * attn_score;
// out_base[t * out_token_stride + row_out] = sum2 + v_attn;
// }
//}
//__syncthreads();
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
#pragma unroll
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
@@ -174,10 +153,6 @@ __global__ void delta_net_recurrent_f32(
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
}
}
//if (t < n_tokens - 1) {
// __syncthreads();
//}
}
}
@@ -415,10 +390,6 @@ static void delta_net_f32_cuda(
const int num_blocks = n_seqs * n_heads;
constexpr int threads_per_block = 512; //256;
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
//const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
//const size_t smem_size = (4 * head_dim + 2 * n_tokens) * sizeof(float);
const size_t smem_size = 4 * head_dim * sizeof(float);
// Use templated kernel for common head dimensions, generic for others
@@ -430,6 +401,7 @@ static void delta_net_f32_cuda(
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
} else {
GGML_ASSERT("Unsupported delta net head size");
delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
}