gpt-oss: add sinks to the attn-vec kernels

commit 3cd7e5c9b4
parent 2b91a9d299
Author: Iwan Kawrakow
Date:   2025-08-10 20:02:33 +03:00

2 changed files with 68 additions and 0 deletions
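Background (not part of the commit): a gpt-oss style attention sink is one learned logit per head that participates in the softmax normalization but has no value vector attached, so it uniformly scales down all attention weights. A minimal scalar sketch of the math the kernels below implement, with hypothetical names:

    // Sketch only: softmax over n scores plus a per-head sink logit.
    // The sink enlarges the denominator but contributes no value.
    #include <math.h>

    static void softmax_with_sink(const float * scores, float * weights, int n, float sink) {
        float m = sink;                                 // running max includes the sink
        for (int i = 0; i < n; ++i) m = fmaxf(m, scores[i]);
        float sum = expf(sink - m);                     // sink term: denominator only
        for (int i = 0; i < n; ++i) {
            weights[i] = expf(scores[i] - m);
            sum += weights[i];
        }
        for (int i = 0; i < n; ++i) weights[i] /= sum;  // weights now sum to < 1
    }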


@@ -72,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16(
     V += nb22*(blockIdx.y / gqa_ratio);
     const half * maskh = (const half *) mask + ne11*ic0;
+    const float * sinksf = (const float *) (sinks);
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
@@ -271,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16(
         __syncthreads();
     }
 
+    if (sinksf) {
+        const half sink = __float2half(sinksf[blockIdx.y]);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.y];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum(kqsum[j]);
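The block added above is a standard online-softmax merge performed once after the KV loop: the running maximum is raised to include the sink, the accumulated sum and output are rescaled by exp(old_max - new_max), and exp(sink - new_max) joins the sum. Only one thread (tid == 0) adds the sink term, so it is counted once per block before the final warp_reduce_sum instead of once per lane; the __half2half2 broadcast rescales the half2 VKQ accumulators. A scalar sketch of the same update, with assumed names:

    // Sketch only: merge a sink logit into online-softmax state, where
    // m = running max, s = sum of exp(score - m), acc = weighted V sum.
    #include <math.h>

    static void merge_sink(float * m, float * s, float * acc, float sink) {
        const float m_new = fmaxf(*m, sink);
        const float scale = expf(*m - m_new);  // rescale old state to the new max
        *s   = *s*scale + expf(sink - m_new);  // sink enters the denominator only
        *acc = *acc*scale;                     // no value vector for the sink
        *m   = m_new;
    }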


@@ -70,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32(
     K += nb12*(blockIdx.y / gqa_ratio);
     V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
     const half * maskh = (const half *) mask + ne11*ic0;
+    const float * sinksf = (const float *) (sinks);
 
     const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
@@ -255,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32(
         __syncthreads();
     }
 
+    if (sinksf) {
+        const float sink = sinksf[blockIdx.y];
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.y];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum(kqsum[j]);
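The f32 path is identical modulo precision (expf instead of hexp, a scalar scale for VKQ). As a sanity check on the update (again not part of the commit), merging the sink after the online pass must give the same denominator as treating the sink as one more logit from the start:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float scores[4] = {1.5f, -0.3f, 2.0f, 0.7f};
        const float sink = 1.0f;

        // Online pass over the scores, then merge the sink, as in the kernels.
        float m = -INFINITY, s = 0.0f;
        for (int i = 0; i < 4; ++i) {
            const float m_new = fmaxf(m, scores[i]);
            s = s*expf(m - m_new) + expf(scores[i] - m_new);
            m = m_new;
        }
        const float m_new = fmaxf(m, sink);
        s = s*expf(m - m_new) + expf(sink - m_new);

        // Direct pass with the sink treated as one more logit.
        float m2 = sink, s2 = 0.0f;
        for (int i = 0; i < 4; ++i) m2 = fmaxf(m2, scores[i]);
        for (int i = 0; i < 4; ++i) s2 += expf(scores[i] - m2);
        s2 += expf(sink - m2);

        printf("online: %.6f  direct: %.6f\n", s, s2);  // should match
        return 0;
    }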