mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-24 08:29:29 +00:00
cuda: add guarded multi-seq fast path for ssm_conv
This commit is contained in:
@@ -164,6 +164,113 @@ static __global__ void ssm_conv_init_states_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void ssm_conv_validate_unique_seq_map(
|
||||
const int32_t * src3,
|
||||
int32_t * seq_ids,
|
||||
int32_t * seq_seen,
|
||||
int32_t * fast_path_ok,
|
||||
int n_t,
|
||||
int n_kv,
|
||||
int src3_nb1) {
|
||||
const int t = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (t >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t * sq = src3 + (size_t) t * src3_nb1;
|
||||
const int32_t seq0 = sq[0];
|
||||
if (seq0 < 0 || seq0 >= n_kv) {
|
||||
atomicExch(fast_path_ok, 0);
|
||||
return;
|
||||
}
|
||||
|
||||
// Fast path supports one sequence per token (no copy-to-multiple-sequences routing).
|
||||
if (n_kv > 1) {
|
||||
const int32_t seq1 = sq[1];
|
||||
if (seq1 >= 0 && seq1 < n_kv) {
|
||||
atomicExch(fast_path_ok, 0);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
seq_ids[t] = seq0;
|
||||
if (atomicAdd(seq_seen + seq0, 1) != 0) {
|
||||
// Sequence is updated by multiple tokens in the same batch => recurrent dependency across t.
|
||||
atomicExch(fast_path_ok, 0);
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void ssm_conv_multi_seq_unique_f32_kernel(
|
||||
const float * src0,
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
const int32_t * seq_ids,
|
||||
float * dst_x,
|
||||
float * dst_state,
|
||||
int nc,
|
||||
int nr,
|
||||
int n_t,
|
||||
int src1_nb1) {
|
||||
const int row = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int t = blockIdx.y;
|
||||
|
||||
if (row >= nr || t >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int seq = seq_ids[t];
|
||||
const float * src_state_row = src0 + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1);
|
||||
float * state_row = dst_state + (size_t) seq * nr * nc + (size_t) row * nc;
|
||||
const float * c_row = src2 + (size_t) row * nc;
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
||||
const float v = src_state_row[i0];
|
||||
state_row[i0] = v;
|
||||
sumf += v * c_row[i0];
|
||||
}
|
||||
|
||||
const float x = src1[row + (size_t) t * src1_nb1];
|
||||
state_row[nc - 1] = x;
|
||||
sumf += x * c_row[nc - 1];
|
||||
dst_x[row + (size_t) t * nr] = sumf;
|
||||
}
|
||||
|
||||
static __global__ void ssm_conv_multi_seq_unique_f32_kernel_nc4(
|
||||
const float * src0,
|
||||
const float * src1,
|
||||
const float * src2,
|
||||
const int32_t * seq_ids,
|
||||
float * dst_x,
|
||||
float * dst_state,
|
||||
int nr,
|
||||
int n_t,
|
||||
int src1_nb1) {
|
||||
const int row = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int t = blockIdx.y;
|
||||
|
||||
if (row >= nr || t >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int seq = seq_ids[t];
|
||||
const float * src_state_row = src0 + (size_t) seq * nr * 3 + (size_t) row * 3;
|
||||
float * state_row = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4;
|
||||
const float * c_row = src2 + (size_t) row * 4;
|
||||
|
||||
const float s0 = src_state_row[0];
|
||||
const float s1 = src_state_row[1];
|
||||
const float s2 = src_state_row[2];
|
||||
const float x = src1[row + (size_t) t * src1_nb1];
|
||||
|
||||
state_row[0] = s0;
|
||||
state_row[1] = s1;
|
||||
state_row[2] = s2;
|
||||
state_row[3] = x;
|
||||
|
||||
dst_x[row + (size_t) t * nr] = s0 * c_row[0] + s1 * c_row[1] + s2 * c_row[2] + x * c_row[3];
|
||||
}
|
||||
|
||||
static __global__ void ssm_conv_f32_kernel(
|
||||
const float * src0,
|
||||
const float * src1,
|
||||
@@ -387,6 +494,57 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
dst_state,
|
||||
nc, nr, n_kv);
|
||||
}
|
||||
|
||||
// Fast path for multi-sequence decode-like batches:
|
||||
// one token per unique sequence, no copy-to-multiple-sequences routing.
|
||||
ggml_cuda_pool_alloc<int32_t> seq_ids(ctx.pool(), n_t);
|
||||
ggml_cuda_pool_alloc<int32_t> seq_seen(ctx.pool(), n_kv);
|
||||
ggml_cuda_pool_alloc<int32_t> fast_path_ok_d(ctx.pool(), 1);
|
||||
int32_t fast_path_ok = 1;
|
||||
|
||||
CUDA_CHECK(cudaMemsetAsync(seq_seen.get(), 0, n_kv * sizeof(int32_t), ctx.stream()));
|
||||
CUDA_CHECK(cudaMemcpyAsync(fast_path_ok_d.get(), &fast_path_ok, sizeof(int32_t), cudaMemcpyHostToDevice, ctx.stream()));
|
||||
|
||||
constexpr int seq_map_block_size = 256;
|
||||
const dim3 seq_map_grid((n_t + seq_map_block_size - 1) / seq_map_block_size, 1, 1);
|
||||
ssm_conv_validate_unique_seq_map<<<seq_map_grid, seq_map_block_size, 0, ctx.stream()>>>(
|
||||
(const int32_t *) src3->data,
|
||||
seq_ids.get(),
|
||||
seq_seen.get(),
|
||||
fast_path_ok_d.get(),
|
||||
n_t,
|
||||
n_kv,
|
||||
src3->nb[1] / sizeof(int32_t));
|
||||
|
||||
CUDA_CHECK(cudaMemcpyAsync(&fast_path_ok, fast_path_ok_d.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, ctx.stream()));
|
||||
CUDA_CHECK(cudaStreamSynchronize(ctx.stream()));
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (fast_path_ok) {
|
||||
const dim3 token_grid(row_grid.x, n_t, 1);
|
||||
if (nc == 4) {
|
||||
ssm_conv_multi_seq_unique_f32_kernel_nc4<<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
seq_ids.get(),
|
||||
dst_x,
|
||||
dst_state,
|
||||
nr, n_t,
|
||||
src1->nb[1] / sizeof(float));
|
||||
} else {
|
||||
ssm_conv_multi_seq_unique_f32_kernel<<<token_grid, block_dims, 0, ctx.stream()>>>(
|
||||
(const float *) src0->data,
|
||||
(const float *) src1->data,
|
||||
(const float *) src2->data,
|
||||
seq_ids.get(),
|
||||
dst_x,
|
||||
dst_state,
|
||||
nc, nr, n_t,
|
||||
src1->nb[1] / sizeof(float));
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (nc == 4) {
|
||||
|
||||
Reference in New Issue
Block a user