From 3cd7e5c9b4e83eb990fa297630c0f4254d49b974 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 10 Aug 2025 20:02:33 +0300 Subject: [PATCH] gpt-oss: add sinks to the attn-vec kernels --- ggml/src/ggml-cuda/fattn-vec-f16.cuh | 34 ++++++++++++++++++++++++++++ ggml/src/ggml-cuda/fattn-vec-f32.cuh | 34 ++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 0c273be1..95dd0e96 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -72,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16( V += nb22*(blockIdx.y / gqa_ratio); const half * maskh = (const half *) mask + ne11*ic0; + const float * sinksf = (const float *) (sinks); const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -271,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); } + if (sinksf) { + const half sink = __float2half(sinksf[blockIdx.y]); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.y]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 9c307502..a97b3737 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -70,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32( K += nb12*(blockIdx.y / gqa_ratio); V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape const half * maskh = (const half *) mask + ne11*ic0; + const float * sinksf = (const float *) (sinks); const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); @@ -255,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); } + if (sinksf) { + const float sink = sinksf[blockIdx.y]; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_shared[j][threadIdx.y]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const float val = expf(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= KQ_max_scale; + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]);