mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 17:14:17 +00:00
More CUDA fused delta net optimizations
This commit is contained in:
@@ -42,11 +42,11 @@ __global__ void delta_net_recurrent_f32(
|
||||
const int64_t output_offset, // offset where state starts in output
|
||||
const float eps) {
|
||||
const int batch_idx = blockIdx.x / n_heads;
|
||||
const int head_idx = blockIdx.x % n_heads;
|
||||
const int head_idx = blockIdx.x % n_heads;
|
||||
const int tid = threadIdx.x;
|
||||
const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads
|
||||
const int lane_id = tid % WARP_SIZE; // 0-31
|
||||
constexpr int NUM_WARPS = block_size/WARP_SIZE; // 256 / 32
|
||||
constexpr int NUM_WARPS = block_size/WARP_SIZE;
|
||||
|
||||
// Strides for input tensors (column-major)
|
||||
// Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
|
||||
@@ -88,8 +88,7 @@ __global__ void delta_net_recurrent_f32(
|
||||
float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
|
||||
float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
|
||||
float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
|
||||
float * sVPrime = sKCumdecay + HEAD_DIM; // HEAD_DIM (state @ k_cumdecay)
|
||||
float * sVNew = sVPrime + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
|
||||
float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
|
||||
|
||||
const float scale = rsqrtf((float)HEAD_DIM);
|
||||
|
||||
@@ -123,52 +122,41 @@ __global__ void delta_net_recurrent_f32(
|
||||
float beta_val = sigmoid_f(beta_ptr[t]);
|
||||
float decay = expf(fminf(g_ptr[t], 50.0f));
|
||||
|
||||
float sum = 0;
|
||||
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
|
||||
sQ[i] = sQ[i] * q_norm * scale;
|
||||
sK[i] = sK[i] * k_norm;
|
||||
sKBeta[i] = sK[i];
|
||||
sVBeta[i] = sV[i] * beta_val;
|
||||
sKCumdecay[i] = sK[i] * beta_val * decay;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
|
||||
float sum = 0.0f;
|
||||
#pragma unroll
|
||||
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
|
||||
sum += state_dst[row_out + col * HEAD_DIM] * sKCumdecay[col];
|
||||
}
|
||||
sum = warp_reduce_sum(sum);
|
||||
if (lane_id == 0) {
|
||||
sVPrime[row_out] = sum;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sum = 0;
|
||||
for (int i = tid; i < HEAD_DIM; i += block_size) {
|
||||
sVNew[i] = sVBeta[i] - sVPrime[i];
|
||||
sum += sK[i] * sQ[i];
|
||||
}
|
||||
float attn_score = reduce_sum<block_size>(sum, sum_helper);
|
||||
//__syncthreads();
|
||||
|
||||
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
|
||||
float sum = 0.0f;
|
||||
float sum1 = 0.0f;
|
||||
float sum2 = 0.0f;
|
||||
#pragma unroll
|
||||
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
|
||||
float state_val = state_dst[row_out + col * HEAD_DIM];
|
||||
sum += sQ[col] * decay * state_val;
|
||||
float sval = state_dst[row_out + col * HEAD_DIM];
|
||||
sum1 += sval * sKCumdecay[col];
|
||||
sum2 += sval * sQ[col];
|
||||
}
|
||||
sum = warp_reduce_sum(sum);
|
||||
sum1 = warp_reduce_sum(sum1);
|
||||
sum2 = warp_reduce_sum(sum2);
|
||||
if (lane_id == 0) {
|
||||
sVNew[row_out] = sVBeta[row_out] - sum1;
|
||||
float v_attn = sVNew[row_out] * attn_score;
|
||||
sOut[row_out] = sum + v_attn;
|
||||
//sOut[row_out] = sum2 * decay + v_attn;
|
||||
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += block_size) {
|
||||
for (int row = 0; row < HEAD_DIM; row++) {
|
||||
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
|
||||
#pragma unroll
|
||||
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
|
||||
float state_val = state_dst[row + out_dim * HEAD_DIM];
|
||||
float safe_decay = decay;
|
||||
if (isnan(safe_decay) || isinf(safe_decay)) {
|
||||
@@ -179,12 +167,10 @@ __global__ void delta_net_recurrent_f32(
|
||||
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = tid; i < HEAD_DIM; i += block_size) {
|
||||
out_base[t * out_token_stride + i] = sOut[i];
|
||||
if (t < n_tokens - 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -420,7 +406,7 @@ static void delta_net_f32_cuda(
|
||||
|
||||
// One block per (batch, head) pair
|
||||
const int num_blocks = n_seqs * n_heads;
|
||||
constexpr int threads_per_block = 256;
|
||||
constexpr int threads_per_block = 512; //256;
|
||||
|
||||
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
|
||||
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
|
||||
@@ -431,6 +417,7 @@ static void delta_net_f32_cuda(
|
||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
} else if (head_dim == 128) {
|
||||
GGML_ASSERT(num_blocks % 8 == 0);
|
||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user