Very slightly better fused delta-net (#1330)

This commit is contained in:
Kawrakow
2026-02-27 07:03:09 +01:00
committed by GitHub
parent 62a7dcac5a
commit facc8fdc44

View File

@@ -96,9 +96,10 @@ __global__ void delta_net_recurrent_f32(
}
constexpr int HEAD_DIM_S = HEAD_DIM + 1;
__shared__ float all_sum[2*HEAD_DIM_S*NUM_WARPS];
constexpr int num_stored_rows = block_size >= HEAD_DIM && block_size % HEAD_DIM == 0 ? block_size/HEAD_DIM : NUM_WARPS;
__shared__ float all_sum[2*HEAD_DIM_S*num_stored_rows];
auto all_sum1 = all_sum;
auto all_sum2 = all_sum1 + HEAD_DIM_S*NUM_WARPS;
auto all_sum2 = all_sum1 + HEAD_DIM_S*num_stored_rows;
// Process each token sequentially
for (int64_t t = 0; t < n_tokens; t++) {
@@ -116,39 +117,68 @@ __global__ void delta_net_recurrent_f32(
float beta_val = sigmoid_f(beta_ptr[t]);
float decay = expf(fminf(g_ptr[t], 50.0f));
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
float sum1 = 0.0f;
float sum2 = 0.0f;
if constexpr (block_size >= HEAD_DIM && block_size % HEAD_DIM == 0) {
int idx = tid / HEAD_DIM;
int row_out = tid % HEAD_DIM;
float sum1 = 0, sum2 = 0;
#pragma unroll
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
for (int col = idx; col < HEAD_DIM; col += block_size/HEAD_DIM) {
float sval = state_dst[row_out + col * HEAD_DIM];
sum1 += sval * sK[col];
sum2 += sval * sQ[col];
}
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
}
__syncthreads();
all_sum1[idx*HEAD_DIM_S + row_out] = sum1;
all_sum2[idx*HEAD_DIM_S + row_out] = sum2;
for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
float sum1 = all_sum1[row_out];
float sum2 = all_sum2[row_out];
#pragma unroll
for (int i = 1; i < NUM_WARPS; ++i) {
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
__syncthreads();
if (idx == 0) {
#pragma unroll
for (int i = 1; i < block_size/HEAD_DIM; ++i) {
sum1 += all_sum1[i*HEAD_DIM_S + row_out];
sum2 += all_sum2[i*HEAD_DIM_S + row_out];
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
__syncthreads();
} else {
for (int row_out = lane_id; row_out < HEAD_DIM; row_out += WARP_SIZE) {
float sum1 = 0.0f;
float sum2 = 0.0f;
#pragma unroll
for (int col = warp_id; col < HEAD_DIM; col += NUM_WARPS) {
float sval = state_dst[row_out + col * HEAD_DIM];
sum1 += sval * sK[col];
sum2 += sval * sQ[col];
}
all_sum1[warp_id*HEAD_DIM_S + row_out] = sum1;
all_sum2[warp_id*HEAD_DIM_S + row_out] = sum2;
}
__syncthreads();
for (int row_out = tid; row_out < HEAD_DIM; row_out += block_size) {
float sum1 = all_sum1[row_out];
float sum2 = all_sum2[row_out];
#pragma unroll
for (int i = 1; i < NUM_WARPS; ++i) {
sum1 += all_sum1[row_out + i*HEAD_DIM_S];
sum2 += all_sum2[row_out + i*HEAD_DIM_S];
}
sVNew[row_out] = sV[row_out] * beta_val - sum1 * beta_val * decay;
float v_attn = sVNew[row_out] * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
}
__syncthreads();
}
__syncthreads();
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
float k_col = sK[out_dim];
#pragma unroll
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
float state_val = state_dst[row + out_dim * HEAD_DIM];
float new_state_val = decay * state_val + sVNew[row] * sK[out_dim];
float new_state_val = decay * state_val + sVNew[row] * k_col; //sK[out_dim];
new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
}