cuda: add guarded multi-seq fast path for ssm_conv

Author: yurko
Date:   2026-02-06 13:52:54 +00:00
parent 89e9ecfa84
commit 236633af99


@@ -164,6 +164,113 @@ static __global__ void ssm_conv_init_states_f32(
    }
}
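
// One thread per token: validate that token t maps to exactly one in-range
// sequence and that no sequence receives more than one token in this batch.
// On success the mapping is recorded in seq_ids[t]; any violation clears
// *fast_path_ok so the caller falls back to the generic path.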
static __global__ void ssm_conv_validate_unique_seq_map(
        const int32_t * src3,
        int32_t * seq_ids,
        int32_t * seq_seen,
        int32_t * fast_path_ok,
        int n_t,
        int n_kv,
        int src3_nb1) {
    const int t = blockIdx.x * blockDim.x + threadIdx.x;
    if (t >= n_t) {
        return;
    }
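    // sq points at the list of sequence ids this token routes to (row t of src3).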
    const int32_t * sq = src3 + (size_t) t * src3_nb1;
    const int32_t seq0 = sq[0];
    if (seq0 < 0 || seq0 >= n_kv) {
        atomicExch(fast_path_ok, 0);
        return;
    }
    // Fast path supports one sequence per token (no copy-to-multiple-sequences routing).
    if (n_kv > 1) {
        const int32_t seq1 = sq[1];
        if (seq1 >= 0 && seq1 < n_kv) {
            atomicExch(fast_path_ok, 0);
            return;
        }
    }
    seq_ids[t] = seq0;
    if (atomicAdd(seq_seen + seq0, 1) != 0) {
        // Sequence is updated by multiple tokens in the same batch => recurrent dependency across t.
        atomicExch(fast_path_ok, 0);
    }
}
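
// Fast-path kernel for the validated mapping: one thread per (row, token).
// Each thread writes the nc-wide state window for its sequence (the previous
// nc-1 values followed by the new input x) and accumulates the dot product
// with the conv weights c_row in the same sweep.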
static __global__ void ssm_conv_multi_seq_unique_f32_kernel(
        const float * src0,
        const float * src1,
        const float * src2,
        const int32_t * seq_ids,
        float * dst_x,
        float * dst_state,
        int nc,
        int nr,
        int n_t,
        int src1_nb1) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int t   = blockIdx.y;
    if (row >= nr || t >= n_t) {
        return;
    }
    const int seq = seq_ids[t];
    const float * src_state_row = src0      + (size_t) seq * nr * (nc - 1) + (size_t) row * (nc - 1);
    float       * state_row     = dst_state + (size_t) seq * nr * nc       + (size_t) row * nc;
    const float * c_row         = src2      + (size_t) row * nc;

    float sumf = 0.0f;
    for (int i0 = 0; i0 < nc - 1; ++i0) {
        const float v = src_state_row[i0];
        state_row[i0] = v;
        sumf += v * c_row[i0];
    }
    const float x = src1[row + (size_t) t * src1_nb1];
    state_row[nc - 1] = x;
    sumf += x * c_row[nc - 1];
    dst_x[row + (size_t) t * nr] = sumf;
}
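
// Specialization for nc == 4 (d_conv == 4): the window is fully unrolled and
// the strides are hardcoded, removing the inner loop of the generic kernel.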
static __global__ void ssm_conv_multi_seq_unique_f32_kernel_nc4(
        const float * src0,
        const float * src1,
        const float * src2,
        const int32_t * seq_ids,
        float * dst_x,
        float * dst_state,
        int nr,
        int n_t,
        int src1_nb1) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    const int t   = blockIdx.y;
    if (row >= nr || t >= n_t) {
        return;
    }
    const int seq = seq_ids[t];
    const float * src_state_row = src0      + (size_t) seq * nr * 3 + (size_t) row * 3;
    float       * state_row     = dst_state + (size_t) seq * nr * 4 + (size_t) row * 4;
    const float * c_row         = src2      + (size_t) row * 4;

    const float s0 = src_state_row[0];
    const float s1 = src_state_row[1];
    const float s2 = src_state_row[2];
    const float x  = src1[row + (size_t) t * src1_nb1];

    state_row[0] = s0;
    state_row[1] = s1;
    state_row[2] = s2;
    state_row[3] = x;
    dst_x[row + (size_t) t * nr] = s0 * c_row[0] + s1 * c_row[1] + s2 * c_row[2] + x * c_row[3];
}

static __global__ void ssm_conv_f32_kernel(
        const float * src0,
        const float * src1,
@@ -387,6 +494,57 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            dst_state,
            nc, nr, n_kv);
        }

        // Fast path for multi-sequence decode-like batches:
        // one token per unique sequence, no copy-to-multiple-sequences routing.
        ggml_cuda_pool_alloc<int32_t> seq_ids(ctx.pool(), n_t);
        ggml_cuda_pool_alloc<int32_t> seq_seen(ctx.pool(), n_kv);
        ggml_cuda_pool_alloc<int32_t> fast_path_ok_d(ctx.pool(), 1);
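        // fast_path_ok starts optimistic (1); any validation thread that detects
        // a violation clears it to 0 with atomicExch.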
        int32_t fast_path_ok = 1;
        CUDA_CHECK(cudaMemsetAsync(seq_seen.get(), 0, n_kv * sizeof(int32_t), ctx.stream()));
        CUDA_CHECK(cudaMemcpyAsync(fast_path_ok_d.get(), &fast_path_ok, sizeof(int32_t), cudaMemcpyHostToDevice, ctx.stream()));

        constexpr int seq_map_block_size = 256;
        const dim3 seq_map_grid((n_t + seq_map_block_size - 1) / seq_map_block_size, 1, 1);
        ssm_conv_validate_unique_seq_map<<<seq_map_grid, seq_map_block_size, 0, ctx.stream()>>>(
            (const int32_t *) src3->data,
            seq_ids.get(),
            seq_seen.get(),
            fast_path_ok_d.get(),
            n_t,
            n_kv,
            src3->nb[1] / sizeof(int32_t));

        CUDA_CHECK(cudaMemcpyAsync(&fast_path_ok, fast_path_ok_d.get(), sizeof(int32_t), cudaMemcpyDeviceToHost, ctx.stream()));
        CUDA_CHECK(cudaStreamSynchronize(ctx.stream()));
        CUDA_CHECK(cudaGetLastError());
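        // The host must observe the validation verdict before choosing a launch,
        // hence the stream sync; if the check failed we fall through to the
        // generic path below.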
        if (fast_path_ok) {
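            // grid.x keeps the per-row blocking of the generic launch (row_grid.x,
            // block_dims); grid.y indexes tokens, which are independent after validation.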
            const dim3 token_grid(row_grid.x, n_t, 1);
            if (nc == 4) {
                ssm_conv_multi_seq_unique_f32_kernel_nc4<<<token_grid, block_dims, 0, ctx.stream()>>>(
                    (const float *) src0->data,
                    (const float *) src1->data,
                    (const float *) src2->data,
                    seq_ids.get(),
                    dst_x,
                    dst_state,
                    nr, n_t,
                    src1->nb[1] / sizeof(float));
            } else {
                ssm_conv_multi_seq_unique_f32_kernel<<<token_grid, block_dims, 0, ctx.stream()>>>(
                    (const float *) src0->data,
                    (const float *) src1->data,
                    (const float *) src2->data,
                    seq_ids.get(),
                    dst_x,
                    dst_state,
                    nc, nr, n_t,
                    src1->nb[1] / sizeof(float));
            }
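            // The fast path has handled every token in the batch.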
            return;
        }
    }
    if (nc == 4) {