This commit is contained in:
Kawrakow
2026-02-27 13:23:54 +00:00
parent 6ac4335155
commit 3c43fe37fa

View File

@@ -137,6 +137,9 @@ __global__ void delta_net_recurrent_f32(
sum1 += all_sum1[i*WARP_SIZE_S + row];
sum2 += all_sum2[i*WARP_SIZE_S + row];
}
// To be honest, I don't understand why we need this sync. But without it I observe results varying from run to run
__syncthreads();
float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
if (col_idx_0 == 0) {
out_base[t * out_token_stride + row_out] = sum2 * decay + sv_new * attn_score;