mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
More CUDA fused delta net optimizations
This commit is contained in:
@@ -42,11 +42,11 @@ __global__ void delta_net_recurrent_f32(
|
|||||||
const int64_t output_offset, // offset where state starts in output
|
const int64_t output_offset, // offset where state starts in output
|
||||||
const float eps) {
|
const float eps) {
|
||||||
const int batch_idx = blockIdx.x / n_heads;
|
const int batch_idx = blockIdx.x / n_heads;
|
||||||
const int head_idx = blockIdx.x % n_heads;
|
const int head_idx = blockIdx.x % n_heads;
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads
|
const int warp_id = tid / WARP_SIZE; // 0-7 for 256 threads
|
||||||
const int lane_id = tid % WARP_SIZE; // 0-31
|
const int lane_id = tid % WARP_SIZE; // 0-31
|
||||||
constexpr int NUM_WARPS = block_size/WARP_SIZE; // 256 / 32
|
constexpr int NUM_WARPS = block_size/WARP_SIZE;
|
||||||
|
|
||||||
// Strides for input tensors (column-major)
|
// Strides for input tensors (column-major)
|
||||||
// Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
|
// Q/K/V: [HEAD_DIM, n_tokens, n_heads, n_seqs]
|
||||||
@@ -88,8 +88,7 @@ __global__ void delta_net_recurrent_f32(
|
|||||||
float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
|
float * sVBeta = sKBeta + HEAD_DIM; // HEAD_DIM (v * sigmoid(beta))
|
||||||
float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
|
float * sOut = sVBeta + HEAD_DIM; // HEAD_DIM
|
||||||
float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
|
float * sKCumdecay = sOut + HEAD_DIM; // HEAD_DIM (k * sigmoid(beta) * exp(g))
|
||||||
float * sVPrime = sKCumdecay + HEAD_DIM; // HEAD_DIM (state @ k_cumdecay)
|
float * sVNew = sKCumdecay + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
|
||||||
float * sVNew = sVPrime + HEAD_DIM; // HEAD_DIM (v_beta - v_prime)
|
|
||||||
|
|
||||||
const float scale = rsqrtf((float)HEAD_DIM);
|
const float scale = rsqrtf((float)HEAD_DIM);
|
||||||
|
|
||||||
@@ -123,52 +122,41 @@ __global__ void delta_net_recurrent_f32(
|
|||||||
float beta_val = sigmoid_f(beta_ptr[t]);
|
float beta_val = sigmoid_f(beta_ptr[t]);
|
||||||
float decay = expf(fminf(g_ptr[t], 50.0f));
|
float decay = expf(fminf(g_ptr[t], 50.0f));
|
||||||
|
|
||||||
|
float sum = 0;
|
||||||
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
|
for (int i = tid; i < HEAD_DIM; i += blockDim.x) {
|
||||||
sQ[i] = sQ[i] * q_norm * scale;
|
sQ[i] = sQ[i] * q_norm * scale;
|
||||||
sK[i] = sK[i] * k_norm;
|
sK[i] = sK[i] * k_norm;
|
||||||
sKBeta[i] = sK[i];
|
sKBeta[i] = sK[i];
|
||||||
sVBeta[i] = sV[i] * beta_val;
|
sVBeta[i] = sV[i] * beta_val;
|
||||||
sKCumdecay[i] = sK[i] * beta_val * decay;
|
sKCumdecay[i] = sK[i] * beta_val * decay;
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
|
|
||||||
float sum = 0.0f;
|
|
||||||
#pragma unroll
|
|
||||||
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
|
|
||||||
sum += state_dst[row_out + col * HEAD_DIM] * sKCumdecay[col];
|
|
||||||
}
|
|
||||||
sum = warp_reduce_sum(sum);
|
|
||||||
if (lane_id == 0) {
|
|
||||||
sVPrime[row_out] = sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
float sum = 0;
|
|
||||||
for (int i = tid; i < HEAD_DIM; i += block_size) {
|
|
||||||
sVNew[i] = sVBeta[i] - sVPrime[i];
|
|
||||||
sum += sK[i] * sQ[i];
|
sum += sK[i] * sQ[i];
|
||||||
}
|
}
|
||||||
float attn_score = reduce_sum<block_size>(sum, sum_helper);
|
float attn_score = reduce_sum<block_size>(sum, sum_helper);
|
||||||
|
//__syncthreads();
|
||||||
|
|
||||||
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
|
for (int row_out = warp_id; row_out < HEAD_DIM; row_out += NUM_WARPS) {
|
||||||
float sum = 0.0f;
|
float sum1 = 0.0f;
|
||||||
|
float sum2 = 0.0f;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
|
for (int col = lane_id; col < HEAD_DIM; col += WARP_SIZE) {
|
||||||
float state_val = state_dst[row_out + col * HEAD_DIM];
|
float sval = state_dst[row_out + col * HEAD_DIM];
|
||||||
sum += sQ[col] * decay * state_val;
|
sum1 += sval * sKCumdecay[col];
|
||||||
|
sum2 += sval * sQ[col];
|
||||||
}
|
}
|
||||||
sum = warp_reduce_sum(sum);
|
sum1 = warp_reduce_sum(sum1);
|
||||||
|
sum2 = warp_reduce_sum(sum2);
|
||||||
if (lane_id == 0) {
|
if (lane_id == 0) {
|
||||||
|
sVNew[row_out] = sVBeta[row_out] - sum1;
|
||||||
float v_attn = sVNew[row_out] * attn_score;
|
float v_attn = sVNew[row_out] * attn_score;
|
||||||
sOut[row_out] = sum + v_attn;
|
//sOut[row_out] = sum2 * decay + v_attn;
|
||||||
|
out_base[t * out_token_stride + row_out] = sum2 * decay + v_attn;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
for (int out_dim = tid; out_dim < HEAD_DIM; out_dim += block_size) {
|
for (int out_dim = warp_id; out_dim < HEAD_DIM; out_dim += NUM_WARPS) {
|
||||||
for (int row = 0; row < HEAD_DIM; row++) {
|
#pragma unroll
|
||||||
|
for (int row = lane_id; row < HEAD_DIM; row += WARP_SIZE) {
|
||||||
float state_val = state_dst[row + out_dim * HEAD_DIM];
|
float state_val = state_dst[row + out_dim * HEAD_DIM];
|
||||||
float safe_decay = decay;
|
float safe_decay = decay;
|
||||||
if (isnan(safe_decay) || isinf(safe_decay)) {
|
if (isnan(safe_decay) || isinf(safe_decay)) {
|
||||||
@@ -179,12 +167,10 @@ __global__ void delta_net_recurrent_f32(
|
|||||||
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
|
state_dst[row + out_dim * HEAD_DIM] = new_state_val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
if (t < n_tokens - 1) {
|
||||||
|
__syncthreads();
|
||||||
for (int i = tid; i < HEAD_DIM; i += block_size) {
|
|
||||||
out_base[t * out_token_stride + i] = sOut[i];
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,7 +406,7 @@ static void delta_net_f32_cuda(
|
|||||||
|
|
||||||
// One block per (batch, head) pair
|
// One block per (batch, head) pair
|
||||||
const int num_blocks = n_seqs * n_heads;
|
const int num_blocks = n_seqs * n_heads;
|
||||||
constexpr int threads_per_block = 256;
|
constexpr int threads_per_block = 512; //256;
|
||||||
|
|
||||||
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
|
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
|
||||||
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
|
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
|
||||||
@@ -431,6 +417,7 @@ static void delta_net_f32_cuda(
|
|||||||
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
delta_net_recurrent_f32<64, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||||
} else if (head_dim == 128) {
|
} else if (head_dim == 128) {
|
||||||
|
GGML_ASSERT(num_blocks % 8 == 0);
|
||||||
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
delta_net_recurrent_f32<128, threads_per_block><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||||
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
q, k, v, g, beta, state_in, dst, n_heads, n_tokens, n_seqs, output_offset, eps);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
Reference in New Issue
Block a user