mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-01 01:24:08 +00:00
Cleanup
This commit is contained in:
@@ -91,17 +91,15 @@ __global__ void delta_net_recurrent_f32_a(
|
||||
constexpr int num_warps = block_size/WARP_SIZE;
|
||||
const int row = tid % WARP_SIZE;
|
||||
const int col_idx_0 = tid / WARP_SIZE;
|
||||
const int row_out = row + sub_idx * WARP_SIZE;
|
||||
|
||||
// Keep the state in registers, copy the final state to its destination at the end
|
||||
float state_local[HEAD_DIM/num_warps];
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
state_local[i] = state_src[col*HEAD_DIM + row + sub_idx * WARP_SIZE];
|
||||
state_local[i] = state_src[col*HEAD_DIM + row_out];
|
||||
}
|
||||
|
||||
//for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
|
||||
// state_dst[col*HEAD_DIM + row + sub_idx * WARP_SIZE] = state_src[col*HEAD_DIM + row + sub_idx * WARP_SIZE];
|
||||
//}
|
||||
|
||||
constexpr int WARP_SIZE_S = WARP_SIZE + 1;
|
||||
constexpr int num_stored_rows = block_size/WARP_SIZE;
|
||||
__shared__ float all_sum[2*WARP_SIZE_S*num_stored_rows];
|
||||
@@ -128,11 +126,6 @@ __global__ void delta_net_recurrent_f32_a(
|
||||
sum1 += state_local[i] * sK[col];
|
||||
sum2 += state_local[i] * sQ[col];
|
||||
}
|
||||
//for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
|
||||
// float sval = state_dst[row + sub_idx * WARP_SIZE + col * HEAD_DIM];
|
||||
// sum1 += sval * sK[col];
|
||||
// sum2 += sval * sQ[col];
|
||||
//}
|
||||
all_sum1[col_idx_0*WARP_SIZE_S + row] = sum1;
|
||||
all_sum2[col_idx_0*WARP_SIZE_S + row] = sum2;
|
||||
|
||||
@@ -144,9 +137,9 @@ __global__ void delta_net_recurrent_f32_a(
|
||||
sum1 += all_sum1[i*WARP_SIZE_S + row];
|
||||
sum2 += all_sum2[i*WARP_SIZE_S + row];
|
||||
}
|
||||
float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row + sub_idx*WARP_SIZE] - sum1 * decay);
|
||||
float sv_new = beta_val * (v_ptr[t * qkv_stride_token + row_out] - sum1 * decay);
|
||||
if (col_idx_0 == 0) {
|
||||
out_base[t * out_token_stride + row + sub_idx*WARP_SIZE] = sum2 * decay + sv_new * attn_score;
|
||||
out_base[t * out_token_stride + row_out] = sum2 * decay + sv_new * attn_score;
|
||||
}
|
||||
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
@@ -156,16 +149,11 @@ __global__ void delta_net_recurrent_f32_a(
|
||||
state_local[i] = new_state_val;
|
||||
}
|
||||
|
||||
//for (int col = col_idx_0; col < HEAD_DIM; col += num_warps) {
|
||||
// float state_val = state_dst[row + sub_idx*WARP_SIZE + col * HEAD_DIM];
|
||||
// float new_state_val = decay * state_val + sv_new * sK[col];
|
||||
// new_state_val = fminf(fmaxf(new_state_val, -1e6f), 1e6f);
|
||||
// state_dst[row + sub_idx*WARP_SIZE + col * HEAD_DIM] = new_state_val;
|
||||
//}
|
||||
}
|
||||
// Copy the final state to its destination
|
||||
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
|
||||
int col = num_warps*i + col_idx_0;
|
||||
state_dst[col*HEAD_DIM + row + sub_idx * WARP_SIZE] = state_local[i];
|
||||
state_dst[col*HEAD_DIM + row_out] = state_local[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user