This commit is contained in:
Kawrakow
2026-02-27 09:37:22 +00:00
parent e217c36475
commit 745dee7d4e

View File

@@ -91,17 +91,15 @@ __global__ void delta_net_recurrent_f32_a(
constexpr int num_warps = block_size/WARP_SIZE;
const int row = tid % WARP_SIZE;
const int col_idx_0 = tid / WARP_SIZE;
const int row_out = row + sub_idx * WARP_SIZE;
// Keep the state in registers, copy the final state to its destination at the end
float state_local[HEAD_DIM/num_warps];
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
int col = num_warps*i + col_idx_0;
state_local[i] = state_src[col*HEAD_DIM + row + sub_idx * WARP_SIZE];
state_local[i] = state_src[col*HEAD_DIM + row_out];
}
//for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
// state_dst[col*HEAD_DIM + row + sub_idx * WARP_SIZE] = state_src[col*HEAD_DIM + row + sub_idx * WARP_SIZE];
//}
constexpr int WARP_SIZE_S = WARP_SIZE + 1;
constexpr int num_stored_rows = block_size/WARP_SIZE;
__shared__ float all_sum[2*WARP_SIZE_S*num_stored_rows];
@@ -128,11 +126,6 @@ __global__ void delta_net_recurrent_f32_a(
sum1 += state_local[i] * sK[col];
sum2 += state_local[i] * sQ[col];
}
//for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
// float sval = state_dst[row + sub_idx * WARP_SIZE + col * HEAD_DIM];
// sum1 += sval * sK[col];
// sum2 += sval * sQ[col];
//}
all_sum1[col_idx_0*WARP_SIZE_S + row] = sum1;
all_sum2[col_idx_0*WARP_SIZE_S + row] = sum2;
@@ -144,9 +137,9 @@ __global__ void delta_net_recurrent_f32_a(
sum1 += all_sum1[i*WARP_SIZE_S + row];
sum2 += all_sum2[i*WARP_SIZE_S + row];
}
float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row + sub_idx*WARP_SIZE] - sum1 * decay);
float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
if (col_idx_0 == 0) {
out_base[t * out_token_stride + row + sub_idx*WARP_SIZE] = sum2 * decay + sv_new * attn_score;
out_base[t * out_token_stride + row_out] = sum2 * decay + sv_new * attn_score;
}
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
@@ -156,16 +149,11 @@ __global__ void delta_net_recurrent_f32_a(
state_local[i] = new_state_val;
}
//for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
// float state_val = state_dst[row + sub_idx*WARP_SIZE + col * HEAD_DIM];
// float new_state_val = decay * state_val + sv_new * sK[col];
// new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
// state_dst[row + sub_idx*WARP_SIZE + col * HEAD_DIM] = new_state_val;
//}
}
// Copy the final state to its destination
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
int col = num_warps*i + col_idx_0;
state_dst[col*HEAD_DIM + row + sub_idx * WARP_SIZE] = state_local[i];
state_dst[col*HEAD_DIM + row_out] = state_local[i];
}
}